"...composable_kernel_rocm.git" did not exist on "2ab8bf4c12ba99854afc406ad24626080ee1acd1"
Commit 53766479 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Merge branch 'ck_tile/improve_async_pipeline' of...

Merge branch 'ck_tile/improve_async_pipeline' of https://github.com/ROCm/composable_kernel into ck_tile/improve_async_pipeline
parents 02b6c6c2 a151fd6d
......@@ -168,6 +168,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(2 <= k0_loops);
static_assert(2 <= k1_loops);
constexpr bool kPreloadWholeNextIterationK = (kM0 <= 64);
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
......@@ -210,11 +212,10 @@ struct BlockFmhaPipelineQRKSVSAsync
using k_tile_type = decltype(load_tile(k_dram_window));
auto k_tiles = [&]() {
// for hdim-96 and hdim-160, try to save vgprs
if constexpr(kQKHeaddim < kSubQKHeaddim)
return statically_indexed_array<k_tile_type, 2>{};
else
if constexpr(kPreloadWholeNextIterationK)
return statically_indexed_array<k_tile_type, k0_loops>{};
else
return statically_indexed_array<k_tile_type, 1>{};
}();
k_tiles[I0] = load_tile(k_dram_window);
......@@ -330,7 +331,7 @@ struct BlockFmhaPipelineQRKSVSAsync
do
{
if constexpr(kQKHeaddim == kSubQKHeaddim)
if constexpr(kPreloadWholeNextIterationK)
{
if(i_total_loops == 0) // executed by fist iteration
{
......@@ -341,7 +342,8 @@ struct BlockFmhaPipelineQRKSVSAsync
k_tiles[number<i_k0>{}]);
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
if constexpr(i_k0 < k0_loops - 2)
move_tile_window(k_dram_window, {0, kK0});
if constexpr(i_k0 == 0)
clear_tile(s_acc);
......@@ -356,7 +358,7 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
k_tiles[number<k0_loops - 1>{}]);
move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
......@@ -375,15 +377,13 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else // there is only single iteration
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0>{}]);
if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 2)
move_tile_window(k_dram_window, {0, kK0});
};
if constexpr(i_k0 == 0)
clear_tile(s_acc);
......@@ -395,6 +395,14 @@ struct BlockFmhaPipelineQRKSVSAsync
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
});
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
k_tiles[number<k0_loops - 1>{}]);
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
// move_tile_window(k_dram_window, {0, -k0_loops * kK0});
}
}
......@@ -440,21 +448,18 @@ struct BlockFmhaPipelineQRKSVSAsync
};
};
}
else // for hdim-96, hdim-160
else // only preload one unroll of K for next iteration
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0 % 2>{}]);
if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<(i_k0 + 1) % 2>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], k_tiles[I0]);
if constexpr(i_k0 == 0)
clear_tile(s_acc);
k_tiles[I0] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 2)
move_tile_window(k_dram_window, {0, kK0});
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc,
......@@ -462,9 +467,16 @@ struct BlockFmhaPipelineQRKSVSAsync
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
});
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], k_tiles[I0]);
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment