Commit 36a1c7c9 authored by Po Yen Chen
Browse files

Use vector loads when the paged V-cache is column-major (async pipeline)

parent 65bbe6ea
......@@ -67,7 +67,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
return kIsPagedKV ? Policy::template GetAlignmentV<Problem>()
: kPadSeqLenK ? 1
: Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
......@@ -555,6 +557,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}
else
{
// Override data points which are located outside [0, seqlen_k) to 0.0
if constexpr(kIsPagedKV && kPadSeqLenK)
{
if(v_page_block_navigator.is_last_block(i_page_block_v))
{
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
set_tile_if(
v_buf,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col =
v_origin.at(number<1>{}) + tile_idx.at(number<1>{});
return physical_seqlen_k_end_ <= col;
});
}
}
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
......@@ -691,6 +711,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}
else
{
// Override data points which are located outside [0, seqlen_k) to 0.0
if constexpr(kIsPagedKV && kPadSeqLenK)
{
if(v_page_block_navigator.is_last_block(i_page_block_v_))
{
const auto v_origin =
v_page_block_navigator.to_global_window_origin(
i_page_block_v_, v_dram_window_.get_window_origin());
set_tile_if(v_buf,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](
auto tile_idx) {
const auto col = v_origin.at(number<1>{}) +
tile_idx.at(number<1>{});
return physical_seqlen_k_end_ <= col;
});
}
}
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment