"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "717d15719c713fd3ee9ab0d8eb3d98116758036e"
Commit b1da29ba authored by Qianfeng Zhang
Browse files

Switch to separate code blocks according to iteration index

parent 90e99a95
...@@ -169,8 +169,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -169,8 +169,10 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(2 <= k0_loops); static_assert(2 <= k0_loops);
static_assert(2 <= k1_loops); static_assert(2 <= k1_loops);
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>(); constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
static_assert(NumKLdsBuffers >= 2);
static_assert(NumVLdsBuffers >= 2); static_assert(NumVLdsBuffers >= 2);
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -190,8 +192,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -190,8 +192,8 @@ struct BlockFmhaPipelineQRKSVSAsync
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr); KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>( auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>()); k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window = auto k_lds_window = make_tile_window(
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kSubQKHeaddim>{}), {0, 0}); k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
...@@ -261,23 +263,18 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -261,23 +263,18 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window_tmp.get_window_lengths(), k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0}); {seqlen_k_start, 0});
auto k_dram_window = make_tile_window( auto k_dram_window =
k_dram_block_window.get_bottom_tensor_view(), make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(), k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>());
using k_tile_type = decltype(load_tile(k_dram_window)); using k_tile_type = decltype(load_tile(k_dram_window));
statically_indexed_array<k_tile_type, k0_loops> k_tiles; statically_indexed_array<k_tile_type, k0_loops> k_tiles;
static_for<0, k0_loops, 1>{}([&](auto i_k0) { k_tiles[I0] = load_tile(k_dram_window);
k_tiles[i_k0] = load_tile(k_dram_window); move_tile_window(k_dram_window, {0, kK0});
move_tile_window(k_dram_window, {0, kK0});
});
move_tile_window(k_dram_window, {0, -k0_loops * kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -318,7 +315,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -318,7 +315,6 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// prefetch K tile
index_t i_total_loops = 0; index_t i_total_loops = 0;
// ensure loading of Q from LDS completely done // ensure loading of Q from LDS completely done
...@@ -326,51 +322,85 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -326,51 +322,85 @@ struct BlockFmhaPipelineQRKSVSAsync
do do
{ {
static_for<0, k0_loops, 1>{}([&](auto i_k0) { if(i_total_loops == 0) // executed by first iteration
auto k_lds_window_tmp = get_slice_tile( {
k_lds_window, sequence<i_k0 * kN0, 0>{}, sequence<(i_k0 + 1) * kN0, kK0>{}); auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
store_tile(k_lds_window_tmp, k_tiles[i_k0]); clear_tile(s_acc); // initialize C
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// STAGE 1, QK gemm static_for<0, k0_loops, 1>{}([&](auto i_k0) {
clear_tile(s_acc); // initialize C if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
if(i_total_loops < num_total_loop - 1) __builtin_amdgcn_sched_barrier(0);
{
move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { block_sync_lds();
k_tiles[i_k0] = load_tile(k_dram_window); // execute current unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
move_tile_window(k_dram_window, {0, kK0}); if constexpr(i_k0 < k0_loops - 1)
{
k_lds_window_tmp = get_slice_tile(
k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
};
}); });
move_tile_window(k_dram_window, {0, -k0_loops * kK0}); move_tile_window(k_dram_window, {0, -k0_loops * kK0});
} }
else // executed by intermediate and last iteration
{
clear_tile(s_acc); // initialize C
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp =
get_slice_tile(k_lds_window,
sequence<(i_k0 % NumKLdsBuffers) * kN0, 0>{},
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
});
};
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// ensure k is completely updated on LDS // executed by first and intermediate iteration
block_sync_lds(); if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0( k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
s_acc,
get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<i_k0 * kN0, 0>{},
sequence<(i_k0 + 1) * kN0, kK0>{}));
});
__builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
});
const auto bias_tile = load_tile(bias_dram_window); // load bias tile move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
}
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
using v_tile_type = decltype(load_tile(v_dram_window)); using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles; statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
...@@ -519,7 +549,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -519,7 +549,7 @@ struct BlockFmhaPipelineQRKSVSAsync
} }
}(); }();
#else #else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif #endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
......
...@@ -296,11 +296,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -296,11 +296,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{ {
if constexpr(KLoadOnce) if constexpr(KLoadOnce)
{ {
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; return 2;
constexpr index_t k0_loops = BlockFmhaShape::kQKHeaddim / BlockFmhaShape::kK0;
return k0_loops;
} }
else else
return 1; return 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment