Commit 76b31460 authored by Po Yen Chen
Browse files

Re-arrange move_tile_window() statements (async)

parent 12871dd4
......@@ -343,10 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start});
// prefetch K tile
async_load_tile_raw(
......@@ -408,7 +406,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
// V DRAM tile window for load
auto v_dram_window = make_tile_window(
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
......@@ -473,6 +475,20 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}
move_tile_window(bias_dram_window, {0, kN0});
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
if(i_total_loops < num_total_loop - 1)
{
// move K tile window
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
// move V tile window
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_block_window, {0, kN0});
}
// only check in first/last iterations
if constexpr(kHasUnevenSplits)
{
......@@ -560,8 +576,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
if(v_page_block_navigator.is_last_block(i_page_block_v))
{
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
set_tile_if(
v_buf,
type_convert<VDataType>(0.0),
......@@ -583,8 +597,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if constexpr(k1_loops > 1)
{
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, kK1});
move_tile_window(v_dram_window, {0, kK1});
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
......@@ -714,9 +727,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
if(v_page_block_navigator.is_last_block(i_page_block_v_))
{
const auto v_origin =
v_page_block_navigator.to_global_window_origin(
i_page_block_v_, v_dram_window_.get_window_origin());
set_tile_if(v_buf,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](
......@@ -736,18 +746,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
move_tile_window(v_dram_window_, {0, kK1});
});
}
i_total_loops++;
// load the first K tile for next iteration
if(i_total_loops < num_total_loop)
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window(
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment