Commit 87b206fb authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Define statically indexed array v_lds_windows[] to reduce using of get_slice_tile()

parent cde3b677
...@@ -313,6 +313,16 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -313,6 +313,16 @@ struct BlockFmhaPipelineQRKSVSAsync
statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles; statically_indexed_array<v_tile_type, NumVLdsBuffers> v_tiles;
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
v_lds_windows[i_buf] = get_slice_tile(
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
});
index_t i_total_loops = 0; index_t i_total_loops = 0;
do do
...@@ -643,18 +653,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -643,18 +653,13 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeShuffledVRegBlockDescriptor<Problem>()); Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[I0]); shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
store_tile( store_tile(
v_lds_window_tmp, v_lds_windows[I0],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
} }
else else
{ {
auto v_lds_window_tmp = store_tile(v_lds_windows[I0],
get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
} }
...@@ -672,13 +677,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -672,13 +677,10 @@ struct BlockFmhaPipelineQRKSVSAsync
v_tiles[I0] = load_tile(v_dram_window); v_tiles[I0] = load_tile(v_dram_window);
block_sync_lds(); block_sync_lds();
gemm_1( gemm_1(o_acc,
o_acc, get_slice_tile(
get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}), v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
get_slice_tile(v_lds_window,
sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>) ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -686,22 +688,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -686,22 +688,14 @@ struct BlockFmhaPipelineQRKSVSAsync
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>( auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>()); Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_tiles[I0]); shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
block_sync_lds(); block_sync_lds();
store_tile(v_lds_window_tmp, store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); tile_elementwise_in(v_element_func, v_shuffle_tmp));
} }
else else
{ {
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
block_sync_lds(); block_sync_lds();
store_tile(v_lds_window_tmp, store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[I0])); tile_elementwise_in(v_element_func, v_tiles[I0]));
} }
...@@ -715,13 +709,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -715,13 +709,10 @@ struct BlockFmhaPipelineQRKSVSAsync
v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window); v_tiles[number<i_k1 % NumVLdsBuffers>{}] = load_tile(v_dram_window);
block_sync_lds(); block_sync_lds();
gemm_1( gemm_1(o_acc,
o_acc, get_slice_tile(
get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}), v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
get_slice_tile(v_lds_window,
sequence<(i_k1 % NumVLdsBuffers) * kN1, 0>{},
sequence<((i_k1 % NumVLdsBuffers) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>) ck_tile::tensor_layout::gemm::RowMajor>)
...@@ -730,20 +721,12 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -730,20 +721,12 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeShuffledVRegBlockDescriptor<Problem>()); Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, shuffle_tile(v_shuffle_tmp,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]); v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]);
auto v_lds_window_tmp = get_slice_tile( store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); tile_elementwise_in(v_element_func, v_shuffle_tmp));
} }
else else
{ {
auto v_lds_window_tmp = get_slice_tile( store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
v_lds_window,
sequence<((i_k1 + 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((i_k1 + 1) % NumVLdsBuffers) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in( tile_elementwise_in(
v_element_func, v_element_func,
v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}])); v_tiles[number<(i_k1 + 1) % NumVLdsBuffers>{}]));
...@@ -759,12 +742,9 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -759,12 +742,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
gemm_1( gemm_1(o_acc,
o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}), v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
get_slice_tile(v_lds_window,
sequence<((k1_loops - 1) % NumVLdsBuffers) * kN1, 0>{},
sequence<(((k1_loops - 1) % NumVLdsBuffers) + 1) * kN1, kK1>{}));
} }
} while(++i_total_loops < num_total_loop); } while(++i_total_loops < num_total_loop);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment