"tests/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "96b6f475dda40a0c7d557f73c36fe09c07be2e9c"
Commit 08598523 authored by Qianfeng Zhang
Browse files

Define a statically indexed array k_lds_windows[] to reduce the use of get_slice_tile()

parent 97efebdb
...@@ -281,6 +281,16 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -281,6 +281,16 @@ struct BlockFmhaPipelineQRKSVSAsync
k_tiles[I0] = load_tile(k_dram_window); k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
using k_lds_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));
statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;
static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_windows[i_buf] = get_slice_tile(
k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
...@@ -309,9 +319,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -309,9 +319,7 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
if(num_total_loop > 1) // there are multiple iterations if(num_total_loop > 1) // there are multiple iterations
{ {
auto k_lds_window_tmp = store_tile(k_lds_windows[I0], k_tiles[I0]);
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
...@@ -321,13 +329,12 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -321,13 +329,12 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); gemm_0(s_acc,
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
k_lds_window_tmp = get_slice_tile( store_tile(k_lds_windows[number<(i_k0 + 1) % NumKLdsBuffers>{}],
k_lds_window, k_tiles[number<i_k0 + 1>{}]);
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
}); });
move_tile_window(k_dram_window, {kN0, -k0_loops * kK0}); move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
...@@ -343,13 +350,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -343,13 +350,13 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute last unroll of gemm_0 // execute last unroll of gemm_0
gemm_0(s_acc, q_tiles[number<k0_loops - 1>{}], k_lds_window_tmp); gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
} }
else // there is only single iteration else // there is only single iteration
{ {
auto k_lds_window_tmp = store_tile(k_lds_windows[I0], k_tiles[I0]);
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
...@@ -362,15 +369,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -362,15 +369,14 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); gemm_0(s_acc,
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
{ {
k_lds_window_tmp = get_slice_tile( store_tile(k_lds_windows[number<(i_k0 + 1) % NumKLdsBuffers>{}],
k_lds_window, k_tiles[number<i_k0 + 1>{}]);
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
}; };
}); });
...@@ -384,11 +390,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -384,11 +390,8 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {kN0, 0}); move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp = get_slice_tile( store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_lds_window, k_tiles[number<i_k0>{}]);
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
k_tiles[number<i_k0>{}] = load_tile(k_dram_window); k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
...@@ -398,7 +401,9 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -398,7 +401,9 @@ struct BlockFmhaPipelineQRKSVSAsync
clear_tile(s_acc); clear_tile(s_acc);
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); gemm_0(s_acc,
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
}); });
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
...@@ -410,22 +415,23 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -410,22 +415,23 @@ struct BlockFmhaPipelineQRKSVSAsync
k_lds_window, k_lds_window,
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{}, sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{}); sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]); store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0>{}]);
if constexpr(i_k0 == 0) if constexpr(i_k0 == 0)
clear_tile(s_acc); clear_tile(s_acc);
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); gemm_0(s_acc,
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
}); });
}; };
}; };
} }
else else
{ {
auto k_lds_window_tmp = store_tile(k_lds_windows[I0], k_tiles[I0]);
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
...@@ -438,15 +444,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -438,15 +444,14 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp); gemm_0(s_acc,
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
{ {
k_lds_window_tmp = get_slice_tile( store_tile(k_lds_windows[number<(i_k0 + 1) % NumKLdsBuffers>{}],
k_lds_window, k_tiles[number<(i_k0 + 1) % 2>{}]);
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<(i_k0 + 1) % 2>{}]);
}; };
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment