Commit db952741 authored by Po Yen Chen
Browse files

Only check incomplete split in first&last iterations

parent 32ef8a18
...@@ -272,10 +272,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -272,10 +272,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
aligned_physical_seqlen_k_start)}, // M/N aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths, v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start});
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q); auto q_tile = tile_elementwise_in(q_element_func, q);
...@@ -289,10 +287,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -289,10 +287,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
do do
{ {
// STAGE 1, QK gemm // STAGE 1, QK gemm
// K DRAM tile window for load
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram_block_window, k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
auto k_block_tile = load_tile(k_dram_window); auto k_block_tile = load_tile(k_dram_window);
{ {
...@@ -334,6 +331,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -334,6 +331,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
k_block_tile = load_tile(k_dram_window); // global read i + 2 k_block_tile = load_tile(k_dram_window); // global read i + 2
}); });
} }
// V DRAM tile window for load
auto v_dram_window = make_tile_window(
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail { // tail
...@@ -402,21 +402,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -402,21 +402,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size // only check in first/last iterations
if constexpr(kHasUnevenSplits) if constexpr(kHasUnevenSplits)
{
if(1 < num_splits && (i_total_loops == 0 || i_total_loops == num_total_loop - 1))
{ {
const auto k_origin = k_page_block_navigator.to_global_window_origin( const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin()); i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if( set_tile_if(s_acc,
s_acc,
-numeric<SMPLComputeDataType>::infinity(), -numeric<SMPLComputeDataType>::infinity(),
[&, [&,
physical_seqlen_k_start_ = physical_seqlen_k_start, physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
{ {
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; return col < physical_seqlen_k_start_ ||
physical_seqlen_k_end_ <= col;
} }
else else
{ {
...@@ -424,6 +427,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -424,6 +427,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
}); });
} }
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
...@@ -444,6 +448,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -444,6 +448,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}); });
} }
} }
__builtin_amdgcn_sched_barrier(0);
// move K tile window
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
__builtin_amdgcn_sched_barrier(0);
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j} const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>( auto m_local = block_tile_reduce<SMPLComputeDataType>(
...@@ -549,12 +558,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -549,12 +558,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(v_lds_window, store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
} }
i_page_block_v = // moving v_dram_window is an in-page-block operation, so there is
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); // no need to invoke v_page_block_navigator.move_tile_window() here.
move_tile_window(v_dram_window, {0, kK1});
const auto p = const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)); cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
__builtin_amdgcn_sched_barrier(0);
// move V tile window
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_block_window, {0, kN0});
__builtin_amdgcn_sched_barrier(0);
// STAGE 3, KV gemm // STAGE 3, KV gemm
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
{ {
...@@ -582,13 +598,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -582,13 +598,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(v_lds_window, store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v tile_elementwise_in(v_element_func, v)); // store next v
} }
i_page_block_v_ = v_page_block_navigator.move_tile_window( move_tile_window(v_dram_window, {0, kK1});
i_page_block_v_, v_dram_window_, {0, kK1});
}); });
} }
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
// tail // tail
{ {
block_sync_lds(); block_sync_lds();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment