Unverified Commit cf2d635e authored by Po Yen Chen, committed by GitHub

[CK_TILE] Fix incorrect computation of group mode PagedAttention (#1688)



* Allow getting batch size from splitkv tile partitioner

* Fix wrong paged-kvcache impl for group mode

* Fix wrong example code for paged-kvcache

* Undo changes in fmha_fwd.cpp

* Always use 2D block table

* Add is_gappy kernel argument for paged-kvcache

The is_gappy argument differentiates the seqstart_k_ptr usage between the
flash-attention and xformers integrations

* Remove out-of-date comments

* Remove no-longer used method

* Fix wrong page-block count calculation

* Fix wrong comment

---------
Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
parent b6bcd76d
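
In short, the new `is_gappy` flag selects between two interpretations of `seqstart_k_ptr` in group mode with a paged KV-cache. Below is a minimal host-side sketch of the two conventions; the helper is illustrative only (not code from this commit) and mirrors the `fmha_fwd_splitkv_args` fields documented in the diff:

```cpp
// Illustrative sketch only -- mirrors the seqlen_k derivation documented in
// fmha_fwd_splitkv_args below; not a function from this commit.
#include <cstdint>

int32_t derive_seqlen_k(bool is_gappy,
                        const int32_t* seqstart_k, // [batch + 1] prefix sums, or per-batch offsets
                        const int32_t* seqlen_k,   // [batch] lengths (may be nullptr)
                        int b)
{
    if(is_gappy)
    {
        // xformers-style: seqstart_k[b] is a local offset into batch b's
        // page-blocks, so the length must come from seqlen_k.
        return seqlen_k[b];
    }
    // flash-attention-style: seqstart_k holds cumulative starts, so adjacent
    // entries give the length (seqlen_k, when provided, takes precedence).
    return seqlen_k != nullptr ? seqlen_k[b] : seqstart_k[b + 1] - seqstart_k[b];
}
```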
@@ -1046,6 +1046,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
     args.batch_stride_block_table = batch_stride_block_table;
     args.page_block_size          = page_block_size;
+    args.is_gappy                 = false; // use 'false' for flash-attention integration
     args.cache_batch_idx =
         (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
......
@@ -165,6 +165,8 @@ struct fmha_fwd_splitkv_args
     void* block_table_ptr;
     ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
     ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
+    bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
+                   // nullptr.
     const void* cache_batch_idx;
@@ -173,12 +175,21 @@ struct fmha_fwd_splitkv_args
     //             seqlen_k = kargs.seqlen_k
     // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
     //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
+    //                        or kargs.seqlen_k_ptr[b]
+    //
     // batch mode (kvcache):
     //             seqlen_q = kargs.seqlen_q
     //             seqlen_k = kargs.seqlen_k_ptr[b]
     // group mode (kvcache):
     //             seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
+    //
+    // when is_gappy=true:
+    //             seqlen_k = kargs.seqlen_k_ptr[b]
+    //             seqstart_k_ptr[b] now stores the local offset of each batch
+    //
+    // when is_gappy=false:
     //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
+    //                        or kargs.seqlen_k_ptr[b]
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
@@ -395,6 +406,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
         args.block_table_ptr,
         args.batch_stride_block_table,
         args.page_block_size,
+        args.is_gappy,
         args.scale_s,
         args.scale_p,
         args.stride_q,
......
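
For orientation, here is a hypothetical host-side fill of the new paged-KV subset of `fmha_fwd_splitkv_args` under an xformers-style gappy layout. This is a sketch, not part of the commit: the pointer parameters and `max_num_page_blocks` are assumed to be device buffers and sizes prepared by the caller, and all other required fields are elided.

```cpp
// Hypothetical sketch: the paged-KV subset of fmha_fwd_splitkv_args for a
// gappy cache. All parameters are assumed caller-provided; other fields elided.
void set_paged_kv_args(fmha_fwd_splitkv_args& args,
                       int32_t* block_table_dev, // 2D [batch, max_num_page_blocks] table
                       ck_tile::index_t max_num_page_blocks,
                       const int32_t* seqstart_k_dev, // per-batch local offsets
                       const int32_t* seqlen_k_dev)   // per-batch key lengths
{
    args.block_table_ptr          = block_table_dev;
    args.batch_stride_block_table = max_num_page_blocks; // row stride of the 2D block table
    args.page_block_size          = 128;  // should be divisible by the kN0 tile size
    args.is_gappy                 = true; // seqstart_k_ptr[b] = local offset in the blocks
    args.seqstart_k_ptr           = seqstart_k_dev;
    args.seqlen_k_ptr             = seqlen_k_dev;
}
```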
@@ -172,13 +172,18 @@ struct FmhaFwdSplitKVKernel
         float scale_p;
     };

-    struct PageBlockTableKargs
+    struct CommonPageBlockTableKargs
     {
         const int32_t* block_table_ptr;
         ck_tile::index_t batch_stride_block_table;
         ck_tile::index_t page_block_size;
     };

+    struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs
+    {
+        bool is_gappy = false;
+    };
+
     struct CacheBatchIdxKargs
     {
         const int32_t* cache_batch_idx;
@@ -193,7 +198,7 @@ struct FmhaFwdSplitKVKernel
                                EmptyKargs<0>>>,
           std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
           std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-          std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
+          std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>
     {
         const int32_t* seqlen_k_ptr;
@@ -215,7 +220,7 @@ struct FmhaFwdSplitKVKernel
                                EmptyKargs<0>>>,
           std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
           std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-          std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
+          std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -375,6 +380,7 @@ struct FmhaFwdSplitKVKernel
                               const void* block_table_ptr,
                               ck_tile::index_t batch_stride_block_table,
                               ck_tile::index_t page_block_size,
+                              bool is_gappy,
                               float scale_s,
                               float scale_p,
                               ck_tile::index_t stride_q,
@@ -461,6 +467,7 @@ struct FmhaFwdSplitKVKernel
             kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
             kargs.batch_stride_block_table = batch_stride_block_table;
             kargs.page_block_size          = page_block_size;
+            kargs.is_gappy                 = is_gappy;
         }

         return kargs;
@@ -495,11 +502,13 @@ struct FmhaFwdSplitKVKernel
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

         long_index_t batch_offset_q       = 0;
-        long_index_t batch_offset_k       = 0;
-        long_index_t batch_offset_v       = 0;
+        long_index_t batch_offset_k       = 0; // unused for paged-kvcache
+        long_index_t batch_offset_v       = 0; // unused for paged-kvcache
         long_index_t batch_offset_bias    = 0;
         long_index_t batch_offset_lse_acc = 0;
         long_index_t batch_offset_o_acc   = 0;
+        index_t kv_l2p_offset =
+            0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache

         if constexpr(kIsGroupMode)
         {
@@ -508,22 +517,14 @@ struct FmhaFwdSplitKVKernel
             const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];

             batch_offset_q = query_start * kargs.stride_q;
-            if constexpr(kIsPagedKV)
+            batch_offset_k = key_start * kargs.stride_k;
+            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
-                batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
+                batch_offset_v = key_start * kargs.stride_v;
             }
             else
             {
-                batch_offset_k = key_start * kargs.stride_k;
-                if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-                {
-                    batch_offset_v = key_start * kargs.stride_v;
-                }
-                else
-                {
-                    batch_offset_v = key_start;
-                }
+                batch_offset_v = key_start;
             }
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
@@ -551,6 +552,15 @@ struct FmhaFwdSplitKVKernel
             {
                 kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
             }
+
+            if constexpr(kIsPagedKV)
+            {
+                if(kargs.is_gappy)
+                {
+                    // seqstart_k_ptr has different meaning in this case
+                    kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
+                }
+            }
         }
         else
         {
@@ -703,7 +713,7 @@ struct FmhaFwdSplitKVKernel
                     reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                     i_batch_ * kargs.batch_stride_block_table;
                 const index_t num_blocks =
-                    integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
+                    integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
                 const long_index_t fixed_offset =
                     static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
@@ -718,7 +728,8 @@ struct FmhaFwdSplitKVKernel
                     kargs.page_block_size,
                     k_dram,
                     make_k_dram(nullptr,
-                                kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
+                                (kv_l2p_offset + kargs.seqlen_k) -
+                                    (num_blocks - 1) * kargs.page_block_size));
             }
             else
             {
@@ -733,7 +744,7 @@ struct FmhaFwdSplitKVKernel
                     reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
                     i_batch_ * kargs.batch_stride_block_table;
                 const index_t num_blocks =
-                    integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
+                    integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
                 const long_index_t fixed_offset =
                     static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
@@ -748,7 +759,8 @@ struct FmhaFwdSplitKVKernel
                     kargs.page_block_size,
                     v_dram,
                     make_v_dram(nullptr,
-                                kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
+                                (kv_l2p_offset + kargs.seqlen_k) -
+                                    (num_blocks - 1) * kargs.page_block_size));
             }
             else
             {
@@ -896,6 +908,7 @@ struct FmhaFwdSplitKVKernel
                 mask,
                 position_encoding,
                 kargs.scale_s,
+                kv_l2p_offset,
                 smem_ptr);
         }
         else
@@ -912,6 +925,7 @@ struct FmhaFwdSplitKVKernel
                 mask,
                 position_encoding,
                 kargs.scale_s,
+                kv_l2p_offset,
                 smem_ptr);
         }
     }();
......
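
The page-block count fix above is easiest to see with numbers. A self-contained check of the arithmetic (the values are made up, and `integer_divide_ceil` is re-declared locally to keep the sketch standalone):

```cpp
// Why num_blocks must include kv_l2p_offset: with is_gappy=true the batch's
// keys start at a non-zero local offset inside its page-blocks.
#include <cassert>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int page_block_size = 128;
    const int kv_l2p_offset   = 200; // seqstart_k_ptr[b] under is_gappy=true
    const int seqlen_k        = 300; // seqlen_k_ptr[b]

    // Old (wrong): ceil(300 / 128) = 3 blocks -- misses the trailing block.
    assert(integer_divide_ceil(seqlen_k, page_block_size) == 3);

    // New (fixed): the keys physically occupy columns [200, 500), i.e. 4 blocks.
    const int num_blocks = integer_divide_ceil(kv_l2p_offset + seqlen_k, page_block_size);
    assert(num_blocks == 4);

    // The last (partial) block holds 500 - 3 * 128 = 116 valid rows, matching
    // the adjusted make_k_dram/make_v_dram length above.
    assert((kv_l2p_offset + seqlen_k) - (num_blocks - 1) * page_block_size == 116);
}
```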
@@ -18,11 +18,11 @@ struct FmhaFwdSplitKVTilePartitioner
     static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
     static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
-                                            ck_tile::index_t nhead,
-                                            ck_tile::index_t max_seqlen_q,
-                                            ck_tile::index_t hdim_v,
-                                            ck_tile::index_t num_splits)
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
+                                                ck_tile::index_t nhead,
+                                                ck_tile::index_t max_seqlen_q,
+                                                ck_tile::index_t hdim_v,
+                                                ck_tile::index_t num_splits)
     {
         // TODO: this may need tuning
         return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
......
@@ -143,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                    FmhaMask mask,
                    PositionEncoding position_encoding,
                    float scale_s,
+                   index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
                    void* smem_ptr) const
     {
         static_assert(
@@ -211,16 +212,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         set_tile(m, -numeric<SMPLComputeDataType>::infinity());
         clear_tile(l);

         const auto q_origin = q_dram_window.get_window_origin();
-        const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
+        const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
             q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);

         // check early exit if no work to do
         if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
         {
-            const index_t original_num_total_loop =
-                integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
-            if(original_num_total_loop <= 0)
+            const index_t logical_num_total_loop =
+                integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
+            if(logical_num_total_loop <= 0)
             {
                 if constexpr(kStoreLSE)
                 {
@@ -239,33 +240,41 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             }
         }

-        // make sure the first tile is completely located in page-block
-        const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
-            if constexpr(kIsPagedKV)
-            {
-                return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
-            }
-            else
-            {
-                return seqlen_k_start_;
-            }
-        }();
+        const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
+        const index_t physical_seqlen_k_end   = logical_seqlen_k_end + kv_l2p_offset;
+
+        // make sure the first tile is completely located in page-block (page-block size should be
+        // divisible by kN0)
+        // relationship between each *_start variables: aligned_physical_seqlen_k_start <=
+        // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
+        const index_t aligned_physical_seqlen_k_start =
+            [&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
+                if constexpr(kIsPagedKV)
+                {
+                    return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
+                }
+                else
+                {
+                    return physical_seqlen_k_start_;
+                }
+            }();

         const index_t num_total_loop =
-            integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
+            integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);

         auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
-            k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
+            k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window =
             make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                              bias_dram_block_window_tmp.get_window_lengths(),
-                             {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
+                             {bias_origin.at(number<0>{}),
+                              logical_seqlen_k_start - (physical_seqlen_k_start -
+                                                        aligned_physical_seqlen_k_start)}, // M/N
                              Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
             v_dram_block_window_lengths,
-            {0, adjusted_seqlen_k_start}, // TODO: hdim split?
+            {0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
             Policy::template MakeVDramTileDistribution<Problem>());

         auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -379,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);

                     s_acc(i_j_idx) *= scale_s;
-                    position_encoding.update(s_acc(i_j_idx), row, col);
+                    // position_encoding accepts only logical coordinates, do conversion here
+                    position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
                 });
             });
         }
@@ -397,29 +407,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             {
                 const auto k_origin = k_page_block_navigator.to_global_window_origin(
                     i_page_block_k, k_dram_block_window.get_window_origin());
-                set_tile_if(s_acc,
-                            -numeric<SMPLComputeDataType>::infinity(),
-                            [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
-                                auto tile_idx) {
-                                const auto col =
-                                    k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-                                if constexpr(kIsPagedKV)
-                                {
-                                    return col < seqlen_k_start_ || seqlen_k_end_ <= col;
-                                }
-                                else
-                                {
-                                    return seqlen_k_end_ <= col;
-                                }
-                            });
+                set_tile_if(
+                    s_acc,
+                    -numeric<SMPLComputeDataType>::infinity(),
+                    [&,
+                     physical_seqlen_k_start_ = physical_seqlen_k_start,
+                     physical_seqlen_k_end_   = physical_seqlen_k_end](auto tile_idx) {
+                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                        if constexpr(kIsPagedKV)
+                        {
+                            return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
+                        }
+                        else
+                        {
+                            return physical_seqlen_k_end_ <= col;
+                        }
+                    });
             }

             if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
             {
                 const auto k_origin = k_page_block_navigator.to_global_window_origin(
                     i_page_block_k, k_dram_block_window.get_window_origin());
+                // mask accepts only logical coordinates, do conversion here
                 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
-                                                           k_origin.at(number<0>{}),
+                                                           k_origin.at(number<0>{}) - kv_l2p_offset,
                                                            number<kM0>{},
                                                            number<kN0>{});

                 if(need_perpixel_check)
@@ -428,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                         s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                             const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                             const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-                            return mask.IsOutOfBound(row, col);
+                            return mask.IsOutOfBound(row, col - kv_l2p_offset);
                         });
                 }
             }
@@ -659,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                    FmhaMask mask,
                    PositionEncoding position_encoding,
                    float scale_s,
+                   index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
                    void* smem_ptr) const
     {
         return operator()(q_dram_block_window_tmp,
@@ -681,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                           mask,
                           position_encoding,
                           scale_s,
+                          kv_l2p_offset,
                           smem_ptr);
     }
 };
......
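
The pipeline changes above all follow one rule: tile windows and page-block navigation use physical seqlen_k coordinates (logical + `kv_l2p_offset`), while the mask and position encoding stay in logical coordinates. A standalone sketch of the start-alignment arithmetic, with made-up values:

```cpp
// Logical vs. physical seqlen_k coordinates in the split-KV pipeline.
// Values are illustrative; kN0 is the tile extent along seqlen_k.
#include <cassert>

constexpr int integer_divide_floor(int a, int b) { return a / b; }

int main()
{
    const int kN0                    = 64;
    const int kv_l2p_offset          = 200; // from seqstart_k_ptr[b] (is_gappy=true)
    const int logical_seqlen_k_start = 0;

    const int physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; // 200

    // Align the first K/V tile down to a kN0 boundary so it sits entirely
    // inside one page-block (page_block_size must be divisible by kN0).
    const int aligned_physical_seqlen_k_start =
        kN0 * integer_divide_floor(physical_seqlen_k_start, kN0); // 192
    assert(aligned_physical_seqlen_k_start <= physical_seqlen_k_start);

    // The bias window stays logical, so its column origin shifts back by the
    // same alignment slack; the 8 out-of-range columns are masked to -inf.
    const int bias_col_origin =
        logical_seqlen_k_start - (physical_seqlen_k_start - aligned_physical_seqlen_k_start);
    assert(bias_col_origin == -8);
}
```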