Commit b1da29ba authored by Qianfeng Zhang

Switch to separate code blocks according to iteration index

parent 90e99a95
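
The diff below splits the QK-gemm inner loop into separate code blocks per iteration index and selects the K LDS slice with a modulo over NumKLdsBuffers, so the store of the next K tile can overlap with the gemm reading the current one. A minimal, self-contained sketch of that buffer-selection pattern is shown here; it uses hypothetical names and plain arrays rather than the kernel's actual tile API.

```cpp
// Illustrative sketch only, not the kernel code: cycling through
// NumKLdsBuffers LDS sub-tiles by unroll index, as the per-iteration
// code blocks in the diff do via get_slice_tile() on k_lds_window.
#include <array>
#include <cstdio>

constexpr int NumKLdsBuffers = 2; // pipeline asserts at least 2 buffers
constexpr int k0_loops       = 4; // unrolled K-dimension steps per block

int main()
{
    // Each unroll step i_k0 writes the freshly loaded K tile into buffer
    // (i_k0 % NumKLdsBuffers) while the gemm consumes the previous buffer.
    std::array<int, NumKLdsBuffers> lds_buffer{};
    for(int i_k0 = 0; i_k0 < k0_loops; ++i_k0)
    {
        const int store_slot   = i_k0 % NumKLdsBuffers;
        lds_buffer[store_slot] = i_k0; // stand-in for store_tile(...)
        std::printf("i_k0=%d stores into LDS buffer %d\n", i_k0, store_slot);
    }
    return 0;
}
```

In the committed code the same modulo picks which kN0-row slice of k_lds_window to store into, with the first iteration handled in its own block so the prefetched k_tiles can be staged while the next K block is loaded from DRAM.
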
@@ -169,8 +169,10 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(2 <= k0_loops);
static_assert(2 <= k1_loops);
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
static_assert(NumKLdsBuffers >= 2);
static_assert(NumVLdsBuffers >= 2);
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -190,8 +192,8 @@ struct BlockFmhaPipelineQRKSVSAsync
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kSubQKHeaddim>{}), {0, 0});
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
@@ -261,23 +263,18 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
auto k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
using k_tile_type = decltype(load_tile(k_dram_window));
statically_indexed_array<k_tile_type, k0_loops> k_tiles;
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
k_tiles[i_k0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
});
move_tile_window(k_dram_window, {0, -k0_loops * kK0});
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
@@ -318,7 +315,6 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0);
// prefetch K tile
index_t i_total_loops = 0;
// ensure loading of Q from LDS completely done
@@ -326,51 +322,85 @@ struct BlockFmhaPipelineQRKSVSAsync
do
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp = get_slice_tile(
k_lds_window, sequence<i_k0 * kN0, 0>{}, sequence<(i_k0 + 1) * kN0, kK0>{});
if(i_total_loops == 0) // executed by fist iteration
{
auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
store_tile(k_lds_window_tmp, k_tiles[i_k0]);
});
clear_tile(s_acc); // initialize C
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, 0});
__builtin_amdgcn_sched_barrier(0);
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
k_tiles[i_k0] = load_tile(k_dram_window);
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
move_tile_window(k_dram_window, {0, kK0});
if constexpr(i_k0 < k0_loops - 1)
{
k_lds_window_tmp = get_slice_tile(
k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
};
});
move_tile_window(k_dram_window, {0, -k0_loops * kK0});
}
else // executed by intermediate and last iteration
{
clear_tile(s_acc); // initialize C
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp =
get_slice_tile(k_lds_window,
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
block_sync_lds();
// execute last unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
});
};
__builtin_amdgcn_sched_barrier(0);
// ensure k is completely updated on LDS
block_sync_lds();
// executed by first and intermediate iteration
if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(
s_acc,
get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<i_k0 * kN0, 0>{},
sequence<(i_k0 + 1) * kN0, kK0>{}));
});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
__builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
});
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
}
__builtin_amdgcn_sched_barrier(0);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
@@ -519,7 +549,7 @@ struct BlockFmhaPipelineQRKSVSAsync
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
......
@@ -296,11 +296,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
if constexpr(KLoadOnce)
{
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
return k0_loops;
return 2;
}
else
return 1;
......