Commit 87b206fb authored by Qianfeng Zhang
Browse files

Define a statically indexed array v_lds_windows[] to reduce the use of get_slice_tile()

parent cde3b677
......@@ -313,6 +313,16 @@ struct BlockFmhaPipelineQRKSVSAsync
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
v_lds_windows[i_buf] = get_slice_tile(
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
});
index_t i_total_loops = 0;
do
......@@ -643,18 +653,13 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
store_tile(
v_lds_window_tmp,
v_lds_windows[I0],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
store_tile(v_lds_window_tmp,
store_tile(v_lds_windows[I0],
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
}
......@@ -672,13 +677,10 @@ struct BlockFmhaPipelineQRKSVSAsync
v_tiles[I0] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(
o_acc,
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(v_lds_window,
sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
......@@ -686,22 +688,14 @@ struct BlockFmhaPipelineQRKSVSAsync
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
block_sync_lds();
store_tile(v_lds_window_tmp,
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
block_sync_lds();
store_tile(v_lds_window_tmp,
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[I0]));
}
......@@ -715,13 +709,10 @@ struct BlockFmhaPipelineQRKSVSAsync
v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window);
block_sync_lds();
gemm_1(
o_acc,
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(v_lds_window,
sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
......@@ -730,20 +721,12 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp));
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(
v_element_func,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]));
......@@ -759,12 +742,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// tail
{
block_sync_lds();
gemm_1(
o_acc,
gemm_1(o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(v_lds_window,
sequence<((k1_loops - 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((k1_loops - 1) % NumVLdsBuffers) + 1) * kN1, kK1>{}));
v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
}
} while(++i_total_loops < num_total_loop);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment