Commit 76b31460 authored by Po Yen Chen
Browse files

Re-arrange move_tile_window() statements (async)

parent 12871dd4
...@@ -343,10 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -343,10 +343,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
aligned_physical_seqlen_k_start)}, // M/N aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths, v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start});
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile // prefetch K tile
async_load_tile_raw( async_load_tile_raw(
...@@ -408,7 +406,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -408,7 +406,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{}); // V DRAM tile window for load
auto v_dram_window = make_tile_window(
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
{ // tail { // tail
gemm_0(s_acc, gemm_0(s_acc,
...@@ -473,6 +475,20 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -473,6 +475,20 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
if(i_total_loops < num_total_loop - 1)
{
// move K tile window
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
// move V tile window
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_block_window, {0, kN0});
}
// only check in first/last iterations // only check in first/last iterations
if constexpr(kHasUnevenSplits) if constexpr(kHasUnevenSplits)
{ {
...@@ -560,8 +576,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -560,8 +576,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{ {
if(v_page_block_navigator.is_last_block(i_page_block_v)) if(v_page_block_navigator.is_last_block(i_page_block_v))
{ {
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
set_tile_if( set_tile_if(
v_buf, v_buf,
type_convert<VDataType>(0.0), type_convert<VDataType>(0.0),
...@@ -583,8 +597,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -583,8 +597,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
{ {
i_page_block_v = v_page_block_navigator.move_tile_window( move_tile_window(v_dram_window, {0, kK1});
i_page_block_v, v_dram_window, {0, kK1});
v_buf = load_tile( v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
} }
...@@ -714,9 +727,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -714,9 +727,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{ {
if(v_page_block_navigator.is_last_block(i_page_block_v_)) if(v_page_block_navigator.is_last_block(i_page_block_v_))
{ {
const auto v_origin =
v_page_block_navigator.to_global_window_origin(
i_page_block_v_, v_dram_window_.get_window_origin());
set_tile_if(v_buf, set_tile_if(v_buf,
type_convert<VDataType>(0.0), type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end]( [&, physical_seqlen_k_end_ = physical_seqlen_k_end](
...@@ -736,18 +746,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -736,18 +746,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
} }
if constexpr(i_k1 < k1_loops - 1) if constexpr(i_k1 < k1_loops - 1)
i_page_block_v_ = v_page_block_navigator.move_tile_window( move_tile_window(v_dram_window_, {0, kK1});
i_page_block_v_, v_dram_window_, {0, kK1});
}); });
} }
i_total_loops++; i_total_loops++;
// load the first K tile for next iteration // load the first K tile for next iteration
if(i_total_loops < num_total_loop) if(i_total_loops < num_total_loop)
{ {
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window( k_dram_window = make_tile_window(
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>()); k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment