"vscode:/vscode.git/clone" did not exist on "571f5efa1d55cd0ef1581977b54358a95f3674be"
Commit a72e100e authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Tune K Lds and V Lds reuse for kPreloadWholeNextIterationK == false

parent 53766479
@@ -168,7 +168,8 @@ struct BlockFmhaPipelineQRKSVSAsync
         static_assert(2 <= k0_loops);
         static_assert(2 <= k1_loops);
-        constexpr bool kPreloadWholeNextIterationK = (kM0 <= 64);
+        constexpr bool kPreloadWholeNextIterationK =
+            Policy::template IsPreloadWholeNextIterationK<Problem>();
         constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
         constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
@@ -247,7 +248,8 @@ struct BlockFmhaPipelineQRKSVSAsync
             Policy::template MakeVDramTileDistribution<Problem>());
         // V tile in LDS
         auto v_lds = make_tensor_view<address_space_enum::lds>(
-            reinterpret_cast<VDataType*>(smem_ptr),
+            reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
+                                         Policy::template GetExclusiveKLdsBytes<Problem>()),
             Policy::template MakeVLdsBlockDescriptor<Problem>());
         auto v_lds_window = make_tile_window(
             v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
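The new base address for the V view is plain byte arithmetic on the shared-memory block: V now starts GetExclusiveKLdsBytes past the beginning, so it can only overlap the reusable tail of K's storage. A standalone C++ illustration of that offsetting (the 48 KiB buffer, the 8192-byte offset, and the float element type are hypothetical stand-ins, not the real VDataType or policy values):

#include <cstdio>

int main()
{
    alignas(128) static char smem[48 * 1024]; // stand-in for the shared-memory block
    constexpr int exclusive_k_bytes = 8192;   // hypothetical GetExclusiveKLdsBytes() result

    // K buffers start at the base; the V view starts past the exclusive K
    // region, so V overlaps only the reusable tail of K's storage.
    float* k_base = reinterpret_cast<float*>(smem);
    float* v_base = reinterpret_cast<float*>(smem + exclusive_k_bytes);

    std::printf("V base is %td bytes after K base\n",
                reinterpret_cast<char*>(v_base) - reinterpret_cast<char*>(k_base));
}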
@@ -450,13 +452,13 @@ struct BlockFmhaPipelineQRKSVSAsync
            }
            else // only preload one unroll of K for next iteration
            {
-               static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
+               static_for<0, k0_loops, 1>{}([&](auto i_k0) {
                    store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], k_tiles[I0]);
                    if constexpr(i_k0 == 0)
                        clear_tile(s_acc);
-                   k_tiles[I0] = load_tile(k_dram_window);
+                   if constexpr(i_k0 < k0_loops - 1)
+                       k_tiles[I0] = load_tile(k_dram_window);
                    if constexpr(i_k0 < k0_loops - 2)
                        move_tile_window(k_dram_window, {0, kK0});
@@ -466,20 +468,6 @@ struct BlockFmhaPipelineQRKSVSAsync
                           q_tiles[number<i_k0>{}],
                           k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
                });
-               store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], k_tiles[I0]);
-               block_sync_lds();
-               gemm_0(s_acc,
-                      q_tiles[number<k0_loops - 1>{}],
-                      k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
-               if(i_total_loops < num_total_loop - 1)
-               {
-                   move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
-                   k_tiles[I0] = load_tile(k_dram_window);
-                   move_tile_window(k_dram_window, {0, kK0});
-               };
            };
            __builtin_amdgcn_sched_barrier(0);
@@ -673,6 +661,20 @@ struct BlockFmhaPipelineQRKSVSAsync
            const auto p =
                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
+           __builtin_amdgcn_sched_barrier(0);
+           if constexpr(!kPreloadWholeNextIterationK)
+           {
+               if(i_total_loops < num_total_loop - 1)
+               {
+                   move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
+                   k_tiles[I0] = load_tile(k_dram_window);
+                   move_tile_window(k_dram_window, {0, kK0});
+               };
+           }
+           __builtin_amdgcn_sched_barrier(0);
            // STAGE 3, KV gemm
            if constexpr(k1_loops > 1)
            {
...
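The hunks above fold the previously peeled last K sub-tile back into the static_for: the loop now runs all k0_loops iterations, with if constexpr guards skipping the load and window move that would run past the last tile, while the cross-iteration K prefetch moves down past the softmax stage (the -673,6 +661,20 hunk). Below is a toy C++17 sketch of that compile-time guard pattern, with a minimal static_for stand-in; it is an illustration of the technique, not the ck_tile implementation:

#include <cstdio>
#include <utility>

// Minimal stand-in for ck_tile's static_for: calls f(integral_constant<int, I>)
// for I = 0..N-1, so each body instantiation sees I as a compile-time constant.
template <class F, int... Is>
constexpr void static_for_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, class F>
constexpr void static_for(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_integer_sequence<int, N>{});
}

int main()
{
    constexpr int k0_loops = 4;
    // One loop covers all k0_loops stages, including the formerly peeled tail:
    static_for<k0_loops>([&](auto i_k0) {
        std::printf("stage %d: store tile, run gemm\n", int(i_k0));
        if constexpr(i_k0 < k0_loops - 1)
            std::printf("stage %d: load next tile\n", int(i_k0)); // no tile after the last
        if constexpr(i_k0 < k0_loops - 2)
            std::printf("stage %d: advance window\n", int(i_k0)); // window already at last column
    });
}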
@@ -88,12 +88,40 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
        return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
    }
+
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr bool IsPreloadWholeNextIterationK()
+   {
+       if constexpr(Problem::BlockFmhaShape::kM0 <= 64)
+           return true;
+       else
+           return false;
+   }
+
+   template <typename Problem>
+   CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
+   {
+       if constexpr(IsPreloadWholeNextIterationK<Problem>())
+       {
+           return 0;
+       }
+       else
+       {
+           // one K LDS sub-buffer cannot be reused by V; pad it to 128 bytes
+           constexpr index_t unreusable_k_lds_bytes =
+               GetSmemSizeK<Problem>() / GetNumKLdsBuffers<Problem>();
+           constexpr index_t unreusable_k_lds_bytes_aligned =
+               ((unreusable_k_lds_bytes + 127) / 128) * 128;
+           return unreusable_k_lds_bytes_aligned;
+       }
+   }
+
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
-       // assume V can reuse the shared memory of K
+       // assume V can reuse the shared memory of K, except the first
+       // (exclusive) K buffer
        // assume Dropout can reuse the shared memory of V
-       return max(GetSmemSizeK<Problem>(),
+       return GetExclusiveKLdsBytes<Problem>() +
+              max(GetSmemSizeK<Problem>() - GetExclusiveKLdsBytes<Problem>(),
                  max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
    }
 };
...
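Net effect of the policy changes: a non-reusable slice of the K LDS is reserved at the front of shared memory, and V (or Dropout) overlaps only the remainder, which can raise the total budget when V is large. A self-contained sketch of the arithmetic with made-up sizes (the helper names and the 32 KiB / 28 KiB figures are illustrative, not the ck_tile values):

#include <algorithm>
#include <cstdio>

// Round bytes up to the next multiple of `alignment` (mirrors the
// ((x + 127) / 128) * 128 padding in GetExclusiveKLdsBytes).
constexpr int align_up(int bytes, int alignment)
{
    return ((bytes + alignment - 1) / alignment) * alignment;
}

// One K LDS sub-buffer must stay live while V is being consumed, so it
// cannot be reused by V; pad it to a 128-byte boundary.
constexpr int exclusive_k_bytes(int smem_k, int num_k_buffers)
{
    return align_up(smem_k / num_k_buffers, 128);
}

// New budget: exclusive K region + the larger of (reusable rest of K)
// and (V / Dropout), matching the reshaped GetSmemSize above.
constexpr int total_smem(int smem_k, int num_k_buffers, int smem_v, int smem_dropout)
{
    const int excl = exclusive_k_bytes(smem_k, num_k_buffers);
    return excl + std::max(smem_k - excl, std::max(smem_v, smem_dropout));
}

int main()
{
    // hypothetical sizes: 32 KiB of K in 4 buffers, 28 KiB of V, no dropout
    constexpr int total = total_smem(32 * 1024, 4, 28 * 1024, 0);
    static_assert(exclusive_k_bytes(32 * 1024, 4) == 8192);
    static_assert(total == 8192 + 28 * 1024); // V no longer overlaps the first K buffer
    // old formula: max(32 KiB, 28 KiB) = 32 KiB; the exclusive region costs 4 KiB here
    std::printf("total LDS bytes: %d\n", total);
}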