Commit 65bbe6ea authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Use vector load if paged-vcache is in column major

parent 1fef9106
......@@ -717,10 +717,15 @@ struct FmhaFwdSplitKVKernel
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
// We assume that page-block size is always divisible by vector size. So we can use
// vector load on seqlen_k direction. However, the seqlen_k may not be divisible by
// vector size as well. So we will have to override data points which are located
// outside [0, seqlen_k) to 0.0 in pipeline.
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<false, kPadSeqLenK>{});
sequence < false,
!kIsPagedKV && kPadSeqLenK > {});
}
};
const auto v_dram = [&]() {
......@@ -786,9 +791,14 @@ struct FmhaFwdSplitKVKernel
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
(kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size));
make_v_dram(nullptr, [&] {
if constexpr(std::is_same_v<VLayout,
ck_tile::tensor_layout::gemm::RowMajor>)
return (kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size;
else
return kargs.page_block_size;
}()));
}
else
{
......
......@@ -63,7 +63,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
// We assume that page-block size is always divisible by vector size. So we can use
// vector load on seqlen_k direction. However, the seqlen_k may not be divisible by
// vector size as well. So we will have to override data points which are located
// outside [0, seqlen_k) to 0.0.
return kIsPagedKV ? Policy::template GetAlignmentV<Problem>()
: kPadSeqLenK ? 1
: Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
......@@ -335,8 +341,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
......@@ -550,6 +556,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
// Override data points which are located outside [0, seqlen_k) to 0.0
if constexpr(kIsPagedKV && kPadSeqLenK)
{
if(v_page_block_navigator.is_last_block(i_page_block_v))
{
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
set_tile_if(
v_prefetch,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col =
v_origin.at(number<1>{}) + tile_idx.at(number<1>{});
return physical_seqlen_k_end_ <= col;
});
}
}
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
......@@ -565,7 +589,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v
auto v = load_tile(v_dram_window_); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
......@@ -583,6 +607,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
else
{
// Override data points which are located outside [0, seqlen_k) to 0.0
if constexpr(kIsPagedKV && kPadSeqLenK)
{
if(v_page_block_navigator.is_last_block(i_page_block_v_))
{
const auto v_origin =
v_page_block_navigator.to_global_window_origin(
i_page_block_v_, v_dram_window_.get_window_origin());
set_tile_if(v,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](
auto tile_idx) {
const auto col = v_origin.at(number<1>{}) +
tile_idx.at(number<1>{});
return physical_seqlen_k_end_ <= col;
});
}
}
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment