Commit 5f4bfa4a authored by Qianfeng Zhang

Tune the prefetching of V in qr_ks_vs_async pipeline

parent 45398bf4
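This commit replaces the single-buffer V prefetch (one v_buf register tile, reloaded between the LDS store and gemm_1 of each k1 iteration) with a ring of NumVLdsBuffers register tiles (v_tiles) that are all loaded before the softmax stage and then refilled one slot per iteration; the pipeline now also requires k1_loops >= 2 and NumVLdsBuffers >= 2. The sketch below illustrates only that rotation pattern and is not the ck_tile API: Tile, load_next_tile and consume are hypothetical placeholders, the runtime loop stands in for the compile-time static_for, and the LDS/scheduling barriers are omitted.

// Minimal, self-contained sketch (placeholder names, not the ck_tile API) of the
// rotating V-prefetch scheme: NumVLdsBuffers tiles are loaded up front, and each
// k1 iteration refills one slot while consuming the next slot in ring order.
// Assumes k1_loops >= NumVLdsBuffers so every prefetched tile is used.
#include <array>
#include <cstddef>

template <typename Tile, std::size_t NumVLdsBuffers, typename Load, typename Consume>
void v_prefetch_pipeline(std::size_t k1_loops, Load load_next_tile, Consume consume)
{
    static_assert(NumVLdsBuffers >= 2, "mirrors the new static_assert in the pipeline");

    std::array<Tile, NumVLdsBuffers> v_tiles{};

    // Prologue: fill every register buffer before the softmax / GEMM-1 stage starts.
    for(std::size_t i = 0; i < NumVLdsBuffers; ++i)
        v_tiles[i] = load_next_tile();

    // The first tile is consumed (stored to LDS) before the k1 loop begins.
    consume(v_tiles[0]);

    for(std::size_t i_k1 = 0; i_k1 + 1 < k1_loops; ++i_k1)
    {
        // Refill the slot whose tile was already consumed, but only while there
        // are still V tiles left to fetch from global memory.
        if(i_k1 + NumVLdsBuffers < k1_loops)
            v_tiles[i_k1 % NumVLdsBuffers] = load_next_tile();

        // Consume the next prefetched tile in ring order; in the real pipeline this
        // is the store into LDS buffer (i_k1 + 1) % NumVLdsBuffers.
        consume(v_tiles[(i_k1 + 1) % NumVLdsBuffers]);
    }
}

The apparent intent of the change is to issue the global V loads earlier and keep them further ahead of the iteration that consumes them, instead of fetching only one tile ahead.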
@@ -161,13 +161,18 @@ struct BlockFmhaPipelineQRKSVSAsync
             kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
             "wrong!");
+        constexpr auto I0 = number<0>{};
+        constexpr auto I1 = number<1>{};
         constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(2 <= k0_loops);
-        static_assert(1 <= k1_loops);
+        static_assert(2 <= k1_loops);
         constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
+        static_assert(NumVLdsBuffers >= 2);
         auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                               q_dram_block_window_tmp.get_window_lengths(),
                                               q_dram_block_window_tmp.get_window_origin(),
@@ -366,7 +371,14 @@ struct BlockFmhaPipelineQRKSVSAsync
         __builtin_amdgcn_sched_barrier(0);
-        auto v_buf = load_tile(v_dram_window); // prefetch load v tile
+        using v_tile_type = decltype(load_tile(v_dram_window));
+        statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
+        static_for<0, NumVLdsBuffers, 1>{}([&](auto i_k1) {
+            v_tiles[i_k1] = load_tile(v_dram_window);
+            move_tile_window(v_dram_window, {0, kK1});
+        });
         // STAGE 2, scale_s, add bias, mask, softmax
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -446,12 +458,12 @@ struct BlockFmhaPipelineQRKSVSAsync
                 s.get_tile_distribution()); // Pcompute{j}
         __builtin_amdgcn_sched_barrier(0);
-        // store & prefetch next v, after the max reduction
         if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
         {
             auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
-            shuffle_tile(v_shuffle_tmp, v_buf);
+            shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
             auto v_lds_window_tmp =
                 get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
@@ -465,14 +477,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             auto v_lds_window_tmp =
                 get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
             store_tile(v_lds_window_tmp,
-                       tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
-        }
-        move_tile_window(v_dram_window, {0, kK1});
-        if constexpr(NumVLdsBuffers > 1)
-        {
-            v_buf = load_tile(v_dram_window); // load next v_buf
-            move_tile_window(v_dram_window, {0, kK1});
+                       tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
         }
         __builtin_amdgcn_sched_barrier(0);
@@ -569,7 +574,8 @@ struct BlockFmhaPipelineQRKSVSAsync
         if constexpr(NumVLdsBuffers == 1)
         {
             static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-                v_buf = load_tile(v_dram_window); // load next v_buf
+                v_tiles[I0] = load_tile(v_dram_window);
                 block_sync_lds();
                 gemm_1(
                     o_acc,
@@ -584,15 +590,14 @@ struct BlockFmhaPipelineQRKSVSAsync
                 {
                     auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                         Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
-                    shuffle_tile(v_shuffle_tmp, v_buf);
+                    shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
                     auto v_lds_window_tmp = get_slice_tile(
                         v_lds_window,
                         sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
                         sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
                     block_sync_lds();
                     store_tile(v_lds_window_tmp,
-                               tile_elementwise_in(v_element_func,
-                                                   v_shuffle_tmp)); // store the prefetch
+                               tile_elementwise_in(v_element_func, v_shuffle_tmp));
                 }
                 else
                 {
@@ -601,18 +606,18 @@ struct BlockFmhaPipelineQRKSVSAsync
                         sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
                         sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
                     block_sync_lds();
-                    store_tile(
-                        v_lds_window_tmp,
-                        tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
+                    store_tile(v_lds_window_tmp,
+                               tile_elementwise_in(v_element_func, v_tiles[I0]));
                 }
                 move_tile_window(v_dram_window, {0, kK1});
             });
         }
         else
         {
             static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-                if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
-                    v_buf = load_tile(v_dram_window); // load next v_buf
+                if constexpr(i_k1 < k1_loops - NumVLdsBuffers)
+                    v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window);
                 block_sync_lds();
                 gemm_1(
@@ -628,14 +633,14 @@ struct BlockFmhaPipelineQRKSVSAsync
                 {
                     auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                         Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
-                    shuffle_tile(v_shuffle_tmp, v_buf);
+                    shuffle_tile(v_shuffle_tmp,
+                                 v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]);
                     auto v_lds_window_tmp = get_slice_tile(
                         v_lds_window,
                         sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
                         sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
                     store_tile(v_lds_window_tmp,
-                               tile_elementwise_in(v_element_func,
-                                                   v_shuffle_tmp)); // store the prefetch
+                               tile_elementwise_in(v_element_func, v_shuffle_tmp));
                 }
                 else
                 {
@@ -643,12 +648,13 @@ struct BlockFmhaPipelineQRKSVSAsync
                         v_lds_window,
                         sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
                         sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
-                    store_tile(
-                        v_lds_window_tmp,
-                        tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
+                    store_tile(v_lds_window_tmp,
+                               tile_elementwise_in(
+                                   v_element_func,
+                                   v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]));
                 }
-                if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
+                if constexpr(i_k1 < k1_loops - NumVLdsBuffers)
                     move_tile_window(v_dram_window, {0, kK1});
             });
         }
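To make the new indexing concrete, here is the schedule the hunks above imply for an illustrative configuration (values not taken from the commit) with NumVLdsBuffers = 2 and k1_loops = 4, where LDS buffer b occupies rows [b * kN1, (b + 1) * kN1) of v_lds_window. The prologue loads V tiles 0 and 1 into v_tiles[0] and v_tiles[1], and tile 0 is stored into LDS buffer 0 before the k1 loop. Iteration i_k1 = 0 satisfies i_k1 < k1_loops - NumVLdsBuffers, so it refills v_tiles[0] with tile 2 ahead of that iteration's gemm_1 and stores v_tiles[(0 + 1) % 2] (tile 1) into LDS buffer 1; iteration i_k1 = 1 refills v_tiles[1] with tile 3 and stores v_tiles[0] (tile 2) into LDS buffer 0; iteration i_k1 = 2 issues no further global load (2 < 2 is false) and stores v_tiles[1] (tile 3) into LDS buffer 1. Each of the k1_loops V tiles is loaded exactly once and lands in alternating LDS buffers.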