"docs/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "2cf1a333b63a303fd4b65dd499f2e9b606e6525a"
Commit db952741 authored by Po Yen Chen
Browse files

Only check incomplete split in first&last iterations

parent 32ef8a18
......@@ -272,10 +272,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start});
auto q_tile = tile_elementwise_in(q_element_func, q);
......@@ -289,10 +287,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
do
{
// STAGE 1, QK gemm
// K DRAM tile window for load
auto k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_block_window, Policy::template MakeKDramTileDistribution<Problem>());
auto k_block_tile = load_tile(k_dram_window);
{
......@@ -334,6 +331,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
}
// V DRAM tile window for load
auto v_dram_window = make_tile_window(
v_dram_block_window, Policy::template MakeVDramTileDistribution<Problem>());
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
......@@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
// only check in first/last iterations
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
if(1 < num_splits && (i_total_loops == 0 || i_total_loops == num_total_loop - 1))
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ ||
physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
......@@ -444,6 +448,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
}
}
__builtin_amdgcn_sched_barrier(0);
// move K tile window
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
__builtin_amdgcn_sched_barrier(0);
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
......@@ -549,12 +558,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
// moving v_dram_window is an in-page-block operation, so there is
// no need to invoke v_page_block_navigator.move_tile_window() here.
move_tile_window(v_dram_window, {0, kK1});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
__builtin_amdgcn_sched_barrier(0);
// move V tile window
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_block_window, {0, kN0});
__builtin_amdgcn_sched_barrier(0);
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
......@@ -582,13 +598,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
move_tile_window(v_dram_window, {0, kK1});
});
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
// tail
{
block_sync_lds();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment