Commit d362410d authored by Qianfeng Zhang
Browse files

Use NumPrefetchV to separate from NumVLdsBuffers

parent 2e612c02
...@@ -173,6 +173,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -173,6 +173,7 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>(); constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>(); constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
static_assert(NumKLdsBuffers >= 2); static_assert(NumKLdsBuffers >= 2);
...@@ -250,7 +251,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -250,7 +251,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using v_tile_type = decltype(load_tile(v_dram_window)); using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles; statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;
using v_lds_window_type = using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{})); decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
...@@ -468,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -468,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window); v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1}); move_tile_window(v_dram_window, {0, kK1});
}); });
...@@ -704,8 +705,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -704,8 +705,8 @@ struct BlockFmhaPipelineQRKSVSAsync
else else
{ {
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumVLdsBuffers) if constexpr(i_k1 < k1_loops - NumPrefetchV)
v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window); v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
block_sync_lds(); block_sync_lds();
gemm_1(o_acc, gemm_1(o_acc,
...@@ -719,19 +720,19 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -719,19 +720,19 @@ struct BlockFmhaPipelineQRKSVSAsync
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>( auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>()); Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, shuffle_tile(v_shuffle_tmp,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]); v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); tile_elementwise_in(v_element_func, v_shuffle_tmp));
} }
else else
{ {
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], store_tile(
tile_elementwise_in( v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
v_element_func, tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}])); v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
} }
if constexpr(i_k1 < k1_loops - NumVLdsBuffers) if constexpr(i_k1 < k1_loops - NumPrefetchV)
move_tile_window(v_dram_window, {0, kK1}); move_tile_window(v_dram_window, {0, kK1});
}); });
} }
......
...@@ -97,6 +97,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -97,6 +97,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
return false; return false;
}; };
// Number of V tiles the pipeline prefetches; the pipeline uses this value to
// size its `v_tiles` array and to bound the V prefetch loop, independently of
// the number of V LDS buffers.
//
// - When IsPreloadWholeNextIterationK<Problem>() is true, the count equals
//   GetNumVLdsBuffers<Problem>().
// - Otherwise the count is capped at 2 (or fewer, if there are fewer V LDS
//   buffers than that).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
{
    // Upper bound in both cases: never prefetch more tiles than V LDS buffers.
    constexpr ck_tile::index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();

    if constexpr(IsPreloadWholeNextIterationK<Problem>())
    {
        return num_v_lds_buffers;
    }
    else
    {
        return min(2, num_v_lds_buffers);
    }
};
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment