Commit 3f29f232 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Only check incomplete split in first&last iterations

parent 4b3474e4
...@@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size // only check in first/last iterations
if constexpr(kHasUnevenSplits) if constexpr(kHasUnevenSplits)
{ {
const auto k_origin = k_page_block_navigator.to_global_window_origin( if(1 < num_splits && (i_total_loops == 0 || i_total_loops == num_total_loop - 1))
i_page_block_k, k_dram_block_window.get_window_origin()); {
set_tile_if( const auto k_origin = k_page_block_navigator.to_global_window_origin(
s_acc, i_page_block_k, k_dram_block_window.get_window_origin());
-numeric<SMPLComputeDataType>::infinity(), set_tile_if(s_acc,
[&, -numeric<SMPLComputeDataType>::infinity(),
physical_seqlen_k_start_ = physical_seqlen_k_start, [&,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { physical_seqlen_k_start_ = physical_seqlen_k_start,
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
if constexpr(kIsPagedKV) const auto col =
{ k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; if constexpr(kIsPagedKV)
} {
else return col < physical_seqlen_k_start_ ||
{ physical_seqlen_k_end_ <= col;
return physical_seqlen_k_end_ <= col; }
} else
}); {
return physical_seqlen_k_end_ <= col;
}
});
}
} }
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment