Commit 2e612c02 authored by Qianfeng Zhang
Browse files

Adjust the pipeline codes

parent a72e100e
......@@ -182,23 +182,10 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeQRegTileDistribution<Problem>());
using q_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_tile_type, k0_loops> q_tiles;
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
q_tiles[number<i_k0>{}] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {0, kK0});
});
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
__builtin_amdgcn_sched_barrier(0);
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
......@@ -222,6 +209,15 @@ struct BlockFmhaPipelineQRKSVSAsync
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
statically_indexed_array<q_tile_type, k0_loops> q_tiles;
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
q_tiles[number<i_k0>{}] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {0, kK0});
});
__builtin_amdgcn_sched_barrier(0);
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
......@@ -239,8 +235,6 @@ struct BlockFmhaPipelineQRKSVSAsync
k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
});
__builtin_amdgcn_sched_barrier(0);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
......@@ -268,8 +262,6 @@ struct BlockFmhaPipelineQRKSVSAsync
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
});
__builtin_amdgcn_sched_barrier(0);
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
......@@ -298,6 +290,8 @@ struct BlockFmhaPipelineQRKSVSAsync
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
......@@ -661,8 +655,6 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
__builtin_amdgcn_sched_barrier(0);
if constexpr(!kPreloadWholeNextIterationK)
{
if(i_total_loops < num_total_loop - 1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment