Commit 5f4bfa4a authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Tune the prefetching of V in qr_ks_vs_async pipeline

parent 45398bf4
......@@ -161,13 +161,18 @@ struct BlockFmhaPipelineQRKSVSAsync
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
static_assert(2 <= k1_loops);
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
static_assert(NumVLdsBuffers >= 2);
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
......@@ -366,7 +371,14 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0);
auto v_buf = load_tile(v_dram_window); // prefetch load v tile
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_k1) {
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
......@@ -446,12 +458,12 @@ struct BlockFmhaPipelineQRKSVSAsync
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
......@@ -465,14 +477,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto v_lds_window_tmp =
get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
move_tile_window(v_dram_window, {0, kK1});
if constexpr(NumVLdsBuffers > 1)
{
v_buf = load_tile(v_dram_window); // load next v_buf
move_tile_window(v_dram_window, {0, kK1});
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
}
__builtin_amdgcn_sched_barrier(0);
......@@ -569,7 +574,8 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(NumVLdsBuffers == 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
v_buf = load_tile(v_dram_window); // load next v_buf
v_tiles[I0] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(
o_acc,
......@@ -584,15 +590,14 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
block_sync_lds();
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
......@@ -601,18 +606,18 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
block_sync_lds();
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_tiles[I0]));
}
move_tile_window(v_dram_window, {0, kK1});
});
}
else
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
v_buf = load_tile(v_dram_window); // load next v_buf
if constexpr(i_k1 < k1_loops - NumVLdsBuffers)
v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(
......@@ -628,14 +633,14 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
shuffle_tile(v_shuffle_tmp,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
......@@ -643,12 +648,13 @@ struct BlockFmhaPipelineQRKSVSAsync
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
store_tile(v_lds_window_tmp,
tile_elementwise_in(
v_element_func,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]));
}
if constexpr(i_k1 > 0 && i_k1 < k1_loops - 1)
if constexpr(i_k1 < k1_loops - NumVLdsBuffers)
move_tile_window(v_dram_window, {0, kK1});
});
}
......
Markdown is supported
Attach a file by dragging & dropping, selecting, or pasting it. 0% uploaded.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment