Commit d3b01d2b authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Re-arrange the codes before the main-loop

parent d55852bc
......@@ -188,6 +188,38 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
__builtin_amdgcn_sched_barrier(0);
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
using k_tile_type = decltype(load_tile(k_dram_window));
auto k_tiles = [&]() {
// for hdim-96 and hdim-160, try to save vgprs
if constexpr(kQKHeaddim < kSubQKHeaddim)
return statically_indexed_array<k_tile_type, 2>{};
else
return statically_indexed_array<k_tile_type, k0_loops>{};
}();
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
......@@ -195,6 +227,23 @@ struct BlockFmhaPipelineQRKSVSAsync
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));
statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;
static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_windows[i_buf] = get_slice_tile(
k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
});
__builtin_amdgcn_sched_barrier(0);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
......@@ -202,6 +251,22 @@ struct BlockFmhaPipelineQRKSVSAsync
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
v_lds_windows[i_buf] = get_slice_tile(
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
});
__builtin_amdgcn_sched_barrier(0);
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
......@@ -230,12 +295,6 @@ struct BlockFmhaPipelineQRKSVSAsync
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
......@@ -257,42 +316,6 @@ struct BlockFmhaPipelineQRKSVSAsync
}
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
using k_tile_type = decltype(load_tile(k_dram_window));
auto k_tiles = [&]() {
// for hdim-96 and hdim-160, try to save vgprs
if constexpr(kQKHeaddim < kSubQKHeaddim)
return statically_indexed_array<k_tile_type, 2>{};
else
return statically_indexed_array<k_tile_type, k0_loops>{};
}();
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
using k_lds_window_type =
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));
statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;
static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_windows[i_buf] = get_slice_tile(
k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
});
__builtin_amdgcn_sched_barrier(0);
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
......@@ -303,26 +326,6 @@ struct BlockFmhaPipelineQRKSVSAsync
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
v_lds_windows[i_buf] = get_slice_tile(
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
});
index_t i_total_loops = 0;
do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment