"...resnet50_tensorflow.git" did not exist on "be4e155bade2cbeb89fff1022c284d72db8f0fd8"
Commit d362410d authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Use NumPrefetchV to separate from NumVLdsBuffers

parent 2e612c02
......@@ -173,6 +173,7 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
static_assert(NumKLdsBuffers >= 2);
......@@ -250,7 +251,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
......@@ -468,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
......@@ -704,8 +705,8 @@ struct BlockFmhaPipelineQRKSVSAsync
else
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 < k1_loops - NumVLdsBuffers)
v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window);
if constexpr(i_k1 < k1_loops - NumPrefetchV)
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(o_acc,
......@@ -719,19 +720,19 @@ struct BlockFmhaPipelineQRKSVSAsync
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]);
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(
v_element_func,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]));
store_tile(
v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func,
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
}
if constexpr(i_k1 < k1_loops - NumVLdsBuffers)
if constexpr(i_k1 < k1_loops - NumPrefetchV)
move_tile_window(v_dram_window, {0, kK1});
});
}
......
......@@ -97,6 +97,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
return false;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV()
{
if constexpr(IsPreloadWholeNextIterationK<Problem>())
return GetNumVLdsBuffers<Problem>();
else
return min(2, GetNumVLdsBuffers<Problem>());
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment