Unverified Commit 74d68e3b authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

[CK_TILE] Simplify the codes in splitkv_combine pipeline (#1549)



* Simplify the codes in splitkv_combine pipeline

* Always set kPadSeqLenK=true for fmha splitkv kernels

* Change in Oacc Alignment and TileDistribution to be more adaptable to tile sizes

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
parent 7733ae16
...@@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: if hdim == 256 or hdim in [32, 64, 128]:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
......
...@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity()); lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{}); block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
static const auto get_validated_m = [](LSEDataType raw_m) {
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
decltype(lse_accum) lse_exp; decltype(lse_accum) lse_exp;
{ {
constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) { if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
constexpr auto i_j_idx = make_tuple(idx0, idx1); {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) = lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); });
}); }
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
});
}
}); });
} }
...@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
{ lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
else else
{ lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
}); });
} }
...@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline
constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) { if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
constexpr auto i_j_idx = make_tuple(idx0, idx1); {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices( const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx); lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{}); const auto col = x_indices.at(number<1>{});
if(col < num_splits) if(col < num_splits)
{ {
const auto row = x_indices.at(number<0>{}); const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) = lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); }
} });
}); }
else
{
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) =
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
}
});
}
}); });
} }
block_sync_lds(); block_sync_lds();
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
} }
......
...@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{ {
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return 16 / sizeof(OaccDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; return GetAlignmentOacc<Problem>();
return 16 / sizeof(ODataType);
} }
template <typename Problem> template <typename Problem>
...@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{ {
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1; constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t N1 = 16 / sizeof(OaccDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = get_warp_size() / N0;
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1); constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution( return make_static_tile_distribution(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment