Commit de6dd79f authored by Po Yen Chen
Browse files

Fix compilation errors

parent 232864b4
...@@ -24,7 +24,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -24,7 +24,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
...@@ -57,7 +56,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -57,7 +56,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -69,7 +69,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -69,7 +69,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
else else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>(); return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}(); }();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>(); static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias = static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>(); kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
...@@ -83,7 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -83,7 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
else else
{ {
// minimize occupancy // minimize occupancy
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS)
{ {
return 1; return 1;
} }
...@@ -119,24 +121,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -119,24 +121,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static constexpr const char* name = "qr_async"; static constexpr const char* name = "qr_async";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowLengths,
typename VDramBlockWindowTmp, typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
typename VElementFunction, typename VElementFunction,
typename BiasElementFunction, typename BiasElementFunction,
typename LSEElementFunction, typename LSEaccElementFunction,
typename SAccElementFunction, typename SAccElementFunction,
typename PComputeElementFunction, typename PComputeElementFunction,
typename OAccElementFunction, typename OAccElementFunction,
...@@ -144,35 +145,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -144,35 +145,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& /*k_element_func*/, const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEaccElementFunction& lse_acc_element_func,
const LSEElementFunction& lse_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func, const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func, const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
DropoutType& dropout) const void* smem_ptr) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> && std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>, std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
...@@ -264,24 +268,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -264,24 +268,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin(); const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}); q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do // check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK) if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{ {
if(num_total_loop <= 0) const index_t logical_num_total_loop =
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
if(logical_num_total_loop <= 0)
{ {
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
auto lse = auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution()); make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity()); set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
} }
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?) // otherwise will have compute error(maybe compiler bug?)
...@@ -292,23 +297,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -292,23 +297,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
} }
auto k_dram_block_window = const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
k_dram_block_window_tmp.get_window_lengths(), // make sure the first tile is completely located in page-block (page-block size should be
{seqlen_k_start, 0}); // divisible by kN0)
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
const index_t aligned_physical_seqlen_k_start =
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
}
else
{
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window,
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
k_dram_window.init_raw(); k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{}; constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() { constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK && if constexpr(kPadSeqLenK && (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || (BiasEnum != BlockAttentionBiasEnum::NO_BIAS)))
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
return bool_constant<true>{}; return bool_constant<true>{};
else else
return bool_constant<false>{}; return bool_constant<false>{};
...@@ -318,16 +338,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -318,16 +338,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
auto bias_dram_window = auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>( auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
randval_dram_block_window_tmp, seqlen_k_start); v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>()); Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile // prefetch K tile
...@@ -438,7 +456,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -438,7 +456,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s; s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col); // position_encoding accept only logical coordinates, do conversion here
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
}); });
}); });
} }
...@@ -450,9 +469,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -450,9 +469,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
#endif #endif
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
// mask accept only logical coordinates, do conversion here
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}), k_origin.at(number<0>{}),
number<kM0>{}, number<kM0>{},
...@@ -464,7 +509,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -464,7 +509,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) { s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col); return mask.IsOutOfBound(row, col - kv_l2p_offset);
}); });
} }
} }
...@@ -513,9 +558,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -513,9 +558,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
{ {
move_tile_window( i_page_block_v = v_page_block_navigator.move_tile_window(
v_dram_window, i_page_block_v, v_dram_window, {0, kK1});
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile( v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
} }
...@@ -595,17 +639,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -595,17 +639,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
}); });
}); });
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = [&]() { const auto p = [&]() {
if constexpr(std::is_same_v<PDataType, fp16_t>) if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>( return impl::cast_tile_pk_fp16_fp32<PDataType>(
...@@ -618,11 +651,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -618,11 +651,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
// STAGE 3, KV gemm // STAGE 3, KV gemm
if constexpr(k1_loops > 1) if constexpr(k1_loops > 1)
{ {
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{ {
v_buf = load_tile( v_buf = load_tile(v_dram_window_,
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf number<-1>{},
bool_constant<false>{}); // load next v_buf
} }
block_sync_lds(); block_sync_lds();
gemm_1(o_acc, gemm_1(o_acc,
...@@ -656,14 +692,17 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -656,14 +692,17 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
} }
if constexpr(i_k1 < k1_loops - 1) if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1}); i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
}); });
} }
i_total_loops++; i_total_loops++;
if(i_total_loops < num_total_loop) if(i_total_loops < num_total_loop)
{ {
// move K tile windows // move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0}); i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 && if constexpr(k1_loops >= 2 &&
...@@ -689,30 +728,30 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -689,30 +728,30 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
} }
} while(i_total_loops < num_total_loop); } while(i_total_loops < num_total_loop);
// store lse // store lse acc
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution()); auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2 #if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI) BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
} }
else else
{ {
lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
} }
#else #else
lse(i_idx) = m_[i_idx] + log(l_[i_idx]); lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif #endif
}); });
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc));
} }
// finally, O // finally, O
...@@ -740,44 +779,51 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -740,44 +779,51 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowLengths,
typename VDramBlockWindowTmp, typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile index_t num_splits,
index_t i_split,
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
DropoutType& dropout) const void* smem_ptr) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
k_dram_block_window_tmp, k_dram_block_window_lengths,
k_page_block_navigator,
identity{}, identity{},
v_dram_block_window_tmp, v_dram_block_window_lengths,
v_page_block_navigator,
identity{}, identity{},
bias_dram_block_window_tmp, bias_dram_block_window_tmp,
identity{}, identity{},
randval_dram_block_window_tmp, lse_acc_dram_block_window_tmp,
lse_dram_block_window_tmp,
identity{}, identity{},
identity{}, identity{},
identity{}, identity{},
identity{}, identity{},
num_splits,
i_split,
mask, mask,
position_encoding, position_encoding,
scale_s, scale_s,
smem_ptr, kv_l2p_offset,
dropout); smem_ptr);
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment