Unverified Commit 74d68e3b authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

[CK_TILE] Simplify the codes in splitkv_combine pipeline (#1549)



* Simplify the codes in splitkv_combine pipeline

* Always set kPadSeqLenK=true for fmha splitkv kernels

* Change in Oacc Alignment and TileDistribution to be more adaptable to tile sizes

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent 7733ae16
......@@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
......
......@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
static const auto get_validated_m = [](LSEDataType raw_m) {
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
decltype(lse_accum) lse_exp;
{
constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) =
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
});
}
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
});
}
});
}
......@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
{
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
else
{
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
});
}
......@@ -218,6 +218,25 @@ struct BlockFmhaFwdSplitKVCombinePipeline
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
}
});
}
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
......@@ -233,22 +252,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
}
});
}
});
}
block_sync_lds();
if constexpr(kStoreLSE)
{
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
......
......@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return 16 / sizeof(OaccDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
return 16 / sizeof(ODataType);
return GetAlignmentOacc<Problem>();
}
template <typename Problem>
......@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t N1 = 16 / sizeof(OaccDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = get_warp_size() / N0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment