Commit a72e100e authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Tune K Lds and V Lds reuse for kPreloadWholeNextIterationK == false

parent 53766479
......@@ -168,7 +168,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert(2 <= k0_loops);
static_assert(2 <= k1_loops);
constexpr bool kPreloadWholeNextIterationK = (kM0 <= 64);
constexpr bool kPreloadWholeNextIterationK =
Policy::template IsPreloadWholeNextIterationK<Problem>();
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
......@@ -247,7 +248,8 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeVDramTileDistribution<Problem>());
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetExclusiveKLdsBytes<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
......@@ -450,13 +452,13 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else // only preload one unroll of K for next iteration
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], k_tiles[I0]);
if constexpr(i_k0 == 0)
clear_tile(s_acc);
k_tiles[I0] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 1)
k_tiles[I0] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 2)
move_tile_window(k_dram_window, {0, kK0});
......@@ -466,20 +468,6 @@ struct BlockFmhaPipelineQRKSVSAsync
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
});
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], k_tiles[I0]);
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
};
__builtin_amdgcn_sched_barrier(0);
......@@ -673,6 +661,20 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
__builtin_amdgcn_sched_barrier(0);
if constexpr(!kPreloadWholeNextIterationK)
{
if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
}
__builtin_amdgcn_sched_barrier(0);
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
......
......@@ -88,12 +88,40 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr bool IsPreloadWholeNextIterationK()
{
    // Whether the pipeline preloads the whole K tile of the next iteration.
    // Only worthwhile for small M0 tiles (kM0 <= 64); larger tiles instead
    // preload K one unroll at a time (see the kPreloadWholeNextIterationK ==
    // false path in the pipeline).
    // NOTE: declared return type was ck_tile::index_t but the function is a
    // predicate — return bool so callers' `constexpr bool` bindings are exact.
    return Problem::BlockFmhaShape::kM0 <= 64;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
{
    // Number of bytes at the start of LDS reserved exclusively for K, i.e.
    // NOT reusable by the V tile. When the whole next-iteration K is
    // preloaded, V may reuse all of K's LDS, so nothing is exclusive.
    if constexpr(IsPreloadWholeNextIterationK<Problem>())
    {
        return 0;
    }
    else
    {
        // One K LDS buffer's worth of bytes (total K LDS divided by the
        // buffer count) is treated as non-reusable — presumably the buffer
        // still holding preloaded K while V occupies the rest; confirm
        // against the pipeline's buffer rotation.
        constexpr index_t unreusable_k_lds_bytes =
            GetSmemSizeK<Problem>() / GetNumKLdsBuffers<Problem>();
        // Round up to a 128-byte boundary so the region following the
        // exclusive K prefix (where V is placed) starts aligned.
        constexpr index_t unreusable_k_lds_bytes_aligned =
            ((unreusable_k_lds_bytes + 127) / 128) * 128;
        return unreusable_k_lds_bytes_aligned;
    }
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    // Total LDS requirement: the K-exclusive prefix plus one shared region.
    // assume V can reuse the other shared memory by K except the first
    // (exclusive) part; assume Dropout can reuse the shared memory by V.
    // NOTE(review): the diff left the stale pre-change `return max(
    // GetSmemSizeK<Problem>(), ...)` line interleaved here, which is not
    // valid C++ — only the post-change return below is kept.
    return GetExclusiveKLdsBytes<Problem>() +
           max(GetSmemSizeK<Problem>() - GetExclusiveKLdsBytes<Problem>(),
               max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment