Commit 65bbe6ea authored by Po Yen Chen
Browse files

Use vector load if paged-vcache is in column major

parent 1fef9106
...@@ -717,10 +717,15 @@ struct FmhaFwdSplitKVKernel ...@@ -717,10 +717,15 @@ struct FmhaFwdSplitKVKernel
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
// We assume that the page-block size is always divisible by the vector size, so we can use
// vector loads along the seqlen_k direction. However, seqlen_k itself may not be divisible by
// the vector size, so the pipeline has to overwrite the data points which are located
// outside [0, seqlen_k) with 0.0.
return pad_tensor_view( return pad_tensor_view(
v_dram_naive, v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<false, kPadSeqLenK>{}); sequence < false,
!kIsPagedKV && kPadSeqLenK > {});
} }
}; };
const auto v_dram = [&]() { const auto v_dram = [&]() {
...@@ -786,9 +791,14 @@ struct FmhaFwdSplitKVKernel ...@@ -786,9 +791,14 @@ struct FmhaFwdSplitKVKernel
num_blocks, num_blocks,
kargs.page_block_size, kargs.page_block_size,
v_dram, v_dram,
make_v_dram(nullptr, make_v_dram(nullptr, [&] {
(kv_l2p_offset + kargs.seqlen_k) - if constexpr(std::is_same_v<VLayout,
(num_blocks - 1) * kargs.page_block_size)); ck_tile::tensor_layout::gemm::RowMajor>)
return (kv_l2p_offset + kargs.seqlen_k) -
(num_blocks - 1) * kargs.page_block_size;
else
return kargs.page_block_size;
}()));
} }
else else
{ {
......
...@@ -63,7 +63,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -63,7 +63,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>(); // We assume that page-block size is always divisible by vector size. So we can use
// vector load on seqlen_k direction. However, seqlen_k itself may not be divisible by
// the vector size, so we have to overwrite the data points which are located
// outside [0, seqlen_k) with 0.0.
return kIsPagedKV ? Policy::template GetAlignmentV<Problem>()
: kPadSeqLenK ? 1
: Policy::template GetAlignmentV<Problem>();
}(); }();
static constexpr index_t kAlignmentOacc = static constexpr index_t kAlignmentOacc =
...@@ -335,8 +341,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -335,8 +341,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}); });
} }
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail { // tail
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, gemm_0(s_acc,
get_slice_tile(q_tile, get_slice_tile(q_tile,
...@@ -550,6 +556,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -550,6 +556,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
else else
{ {
// Override data points which are located outside [0, seqlen_k) to 0.0
if constexpr(kIsPagedKV && kPadSeqLenK)
{
if(v_page_block_navigator.is_last_block(i_page_block_v))
{
const auto v_origin = v_page_block_navigator.to_global_window_origin(
i_page_block_v, v_dram_window.get_window_origin());
set_tile_if(
v_prefetch,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col =
v_origin.at(number<1>{}) + tile_idx.at(number<1>{});
return physical_seqlen_k_end_ <= col;
});
}
}
store_tile(v_lds_window, store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
} }
...@@ -565,7 +589,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -565,7 +589,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static_for<0, k1_loops - 1, 1>{}([&, static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v, &i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) { &v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v auto v = load_tile(v_dram_window_); // load next v
block_sync_lds(); block_sync_lds();
gemm_1(o_acc, gemm_1(o_acc,
get_slice_tile( get_slice_tile(
...@@ -583,6 +607,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -583,6 +607,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
else else
{ {
// Override data points which are located outside [0, seqlen_k) to 0.0
if constexpr(kIsPagedKV && kPadSeqLenK)
{
if(v_page_block_navigator.is_last_block(i_page_block_v_))
{
const auto v_origin =
v_page_block_navigator.to_global_window_origin(
i_page_block_v_, v_dram_window_.get_window_origin());
set_tile_if(v,
type_convert<VDataType>(0.0),
[&, physical_seqlen_k_end_ = physical_seqlen_k_end](
auto tile_idx) {
const auto col = v_origin.at(number<1>{}) +
tile_idx.at(number<1>{});
return physical_seqlen_k_end_ <= col;
});
}
}
store_tile(v_lds_window, store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v tile_elementwise_in(v_element_func, v)); // store next v
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment