"docs/vscode:/vscode.git/clone" did not exist on "111f33be5aab4cae34e4f5601a2069421dd2c8e9"
Commit 400cac28 authored by coderfeli's avatar coderfeli
Browse files

Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8

parents 7cec63a6 4c2eff02
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
......
...@@ -71,7 +71,8 @@ struct FmhaFwdKernel ...@@ -71,7 +71,8 @@ struct FmhaFwdKernel
using bfs = typename FmhaPipeline::BlockFmhaShape; using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps; using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps; using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile; using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
auto pn = [&] () { auto pn = [&] () {
...@@ -88,7 +89,8 @@ struct FmhaFwdKernel ...@@ -88,7 +89,8 @@ struct FmhaFwdKernel
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
......
...@@ -8,9 +8,11 @@ namespace ck_tile { ...@@ -8,9 +8,11 @@ namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel struct FmhaFwdSplitKVCombineKernel
{ {
using TilePartitioner = remove_cvref_t<TilePartitioner_>; using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>; using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps;
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0); static_assert(kBlockPerCu > 0);
...@@ -50,8 +52,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -50,8 +52,7 @@ struct FmhaFwdSplitKVCombineKernel
return return
_SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) + _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" "_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + "b" + _TS_(FmhaPipeline::kN1) + "_" +
_TS_(FmhaPipeline::kN1) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) + _SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) + (pn.empty() ? "" : "_" + pn) +
...@@ -339,37 +340,56 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -339,37 +340,56 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{}, number<FmhaPipeline::kAlignmentOacc>{},
number<1>{}); number<1>{});
// read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
const auto o_acc_dram_view = pad_tensor_view( const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive, o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), make_tuple(
sequence<false, kPadSeqLenQ, kPadHeadDimV>{}); number<kNumWarps>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<true, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_num_splits =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
const index_t padded_seqlen_q = const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v = const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view( const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);
// transform tensor view by following steps, given shape: (padded_num_splits,
// padded_seqlen_q, padded_hdim_v)
// 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
// 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
// 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
auto transposed = transform_tensor_view(
o_acc_dram_view, o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)), make_tuple(make_pass_through_transform(padded_num_splits),
make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
make_pass_through_transform(padded_hdim_v)), make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<1>{}, sequence<0, 2>{}, sequence<3>{}));
return transform_tensor_view(
transposed,
make_tuple(make_merge_transform(
make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
}(); }();
auto lse_acc_dram_window = make_tile_window( auto lse_acc_dram_window = make_tile_window(
lse_acc_dram, lse_acc_dram,
[&]() { make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
}(),
{0, i_m0}); {0, i_m0});
const index_t padded_num_splits =
integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;
auto o_acc_dram_window = make_tile_window( auto o_acc_dram_window = make_tile_window(
o_acc_dram, o_acc_dram,
[&]() { make_tuple(number<kNumWarps * FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}); {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});
}(),
{i_m0, i_n1});
// LSE DRAM window // LSE DRAM window
auto lse_dram_window = [&, i_nhead_ = i_nhead]() { auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
...@@ -410,7 +430,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -410,7 +430,6 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits, kargs.num_splits,
kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
else else
...@@ -419,7 +438,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -419,7 +438,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window, o_acc_dram_window,
lse_dram_window, lse_dram_window,
kargs.num_splits, kargs.num_splits,
kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
}(); }();
......
...@@ -45,6 +45,7 @@ struct FmhaFwdSplitKVKernel ...@@ -45,6 +45,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
...@@ -67,7 +68,8 @@ struct FmhaFwdSplitKVKernel ...@@ -67,7 +68,8 @@ struct FmhaFwdSplitKVKernel
using bfs = typename FmhaPipeline::BlockFmhaShape; using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps; using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps; using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile; using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
auto pn = [&] () { auto pn = [&] () {
...@@ -84,11 +86,12 @@ struct FmhaFwdSplitKVKernel ...@@ -84,11 +86,12 @@ struct FmhaFwdSplitKVKernel
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
// clang-format on // clang-format on
......
...@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr index_t kNumWarps = Problem::kNumWarps;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kHeadDimV = Problem::kHeadDimV; static constexpr index_t kHeadDimV = Problem::kHeadDimV;
...@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func, const OaccElementFunction& o_acc_element_func,
index_t num_splits, index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const void* smem_ptr) const
{ {
// lse_acc tile in LDS // lse_acc tile in LDS
...@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]). // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto lse_acc_tile = load_tile(lse_acc_dram_window); auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile); store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
auto lse_accum = make_static_distributed_tensor<LSEDataType>( auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>()); Policy::template MakeLSEaccRegTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits]) // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region. // and fill up -INF values outside the [kM0, num_splits] region.
{ {
...@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
} }
}); });
} }
block_sync_lds();
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
} }
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>(); auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
auto o_acc_dram_window = auto o_acc_4_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(), o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(), o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist); o_acc_4_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0; // shape=[4 * KM0, kN1]
auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
clear_tile(o_acc_4);
for(index_t i_split = 0; i_split < num_splits; ++i_split) const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// each warp handles a [KM0, kN1] tile
for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
{ {
auto o_tile = load_tile(o_acc_dram_window); auto o_tile = load_tile(o_acc_4_dram_window);
const index_t i_split = split_start + get_warp_id();
const index_t row_start = kM0 * get_warp_id();
{ {
constexpr auto spans = decltype(o_acc)::get_distributed_spans(); constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) { sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices( const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx); o_acc_4.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{}); const auto row = x_indices.at(number<0>{});
const LSEDataType lse_scale = lse_acc_lds(row, i_split); const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
}); });
}); });
} }
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0}); move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
}
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
{
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0});
}();
store_tile(o_acc_4_lds_window, o_acc_4);
} }
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
}();
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, kNumWarps, 1>{}([&](auto) {
auto o_acc_in = load_tile(o_acc_4_lds_window);
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) += o_acc_in(i_j_idx);
});
});
}
move_tile_window(o_acc_4_lds_window, {kM0, 0});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc; return o_acc;
...@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window, const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window, LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits, index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const void* smem_ptr) const
{ {
return operator()(lse_acc_dram_block_window, return operator()(lse_acc_dram_block_window,
...@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{}, identity{},
identity{}, identity{},
num_splits, num_splits,
seqlen_q,
smem_ptr); smem_ptr);
} }
}; };
......
...@@ -10,23 +10,38 @@ namespace ck_tile { ...@@ -10,23 +10,38 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
template <index_t BlockSize, index_t M, index_t N, typename DataType> template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile()
{
static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
if constexpr(0 < ElemPerThread)
{
return NumWarps;
}
else
{ // try dividing tile by smaller # of warps
return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
}
}
template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile() CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{ {
constexpr index_t PixelsPerThread = (M * N) / BlockSize; constexpr index_t MaxNumWarps = GetMaxNumWarpsForTile<NumWarps, M, N, DataType>();
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNPerThread = 16 / sizeof(DataType); constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
return NPerThread; constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
return min(MaxNPerThread, ElemPerThread);
} }
// alignment for dram lse tile (shape=[kMaxSplits, kM0]) // alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{ {
return GetVectorSizeForTile<Problem::kBlockSize, return GetVectorSizeForTile<Problem::kNumWarps,
Problem::kMaxSplits, Problem::kMaxSplits,
Problem::kM0, Problem::kM0,
typename Problem::LSEDataType>(); typename Problem::LSEDataType>();
...@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc()
{ {
return sizeof(typename Problem::LSEDataType) * return sizeof(typename Problem::LSEDataType) *
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size(); MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
{
return sizeof(typename Problem::OaccDataType) *
MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
}
// shape=[kMaxSplits, kM0] // shape=[kMaxSplits, kM0]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t MaxNumWarps =
GetMaxNumWarpsForTile<Problem::kNumWarps, kNPerBlock, kMPerBlock, LSEDataType>();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
constexpr index_t NPerThread = constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>(); GetVectorSizeForTile<MaxNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp); constexpr index_t MPerThread = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);
static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock); static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<Replicate>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>, tuple<sequence<MPerThread, MaxNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0, 1>, sequence<2, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
...@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>(); GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}), make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{}, number<NPack>{},
number<1>{}); number<1>{});
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
...@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>(); GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}), make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{}, number<NPack>{},
number<1>{}); number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
...@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_t_lds_block_desc; return lse_acc_t_lds_block_desc;
} }
// 3d + padding, shape=[4 * kM0, kN1]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = 4 * Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t NPack =
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
o_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return o_acc_t_lds_block_desc;
}
// shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NThreads = 4; constexpr index_t MaxNThreads = 8;
constexpr index_t NPerThread = kNPerBlock / NThreads; constexpr index_t NThreads = min(kNPerBlock, MaxNThreads);
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads; constexpr index_t MPerThread = 1;
constexpr index_t MPerThread = kMPerBlock / MThreads; constexpr index_t MThreads = kMPerBlock / MPerThread;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock); static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<sequence<Replicate>,
sequence<1>, tuple<sequence<MaxNumWarps, MThreadPerWarp, MPerThread>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>, tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>, tuple<sequence<0, 0>, sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<2, 1>>{}); sequence<2, 1>>{});
}
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
// direction
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
constexpr index_t kNPerBlock = Problem::kN1;
static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = 1; // compose encoding base on 1 warp
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1, 1>, sequence<1, 2>>,
tuple<sequence<0, 2>, sequence<3, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
} }
template <typename Problem> template <typename Problem>
...@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1; constexpr index_t kNPerBlock = Problem::kN1;
static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
// TODO: this is for 3d layout
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return static_cast<index_t>(16 / sizeof(QDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
{
using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
return static_cast<index_t>(16 / sizeof(SDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kNPack = GetSmemNPackS<Problem>();
constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
number<kNPack>{},
number<1>{});
constexpr auto s_lds_block_desc = transform_tensor_descriptor(
s_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return s_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1, "Check failed!");
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = kKPerBlock / (K2 * K3);
constexpr index_t K0 = kTileK / kKPerBlock;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr auto s2_block_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
sequence<1, 2, 2, 2>,
sequence<0, 0, 1, 3>>{};
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
return s2_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
{
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
}
};
} // namespace ck_tile
...@@ -43,8 +43,6 @@ struct TileFmhaShape ...@@ -43,8 +43,6 @@ struct TileFmhaShape
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
......
...@@ -22,7 +22,8 @@ template <bool IsGateOnly_, ...@@ -22,7 +22,8 @@ template <bool IsGateOnly_,
FusedMoeGemmWeightPermuteEnum PermuteEnum_ = FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten, FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false, bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false> bool PadIntermediateSize_ = false,
bool PipeInterleave_ = true>
struct FusedMoeGemmTraits struct FusedMoeGemmTraits
{ {
// Gate+Up or Gate only // Gate+Up or Gate only
...@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits ...@@ -32,6 +33,7 @@ struct FusedMoeGemmTraits
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_; static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_; static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_; static constexpr bool PadIntermediateSize = PadIntermediateSize_;
static constexpr bool PipeInterleave = PipeInterleave_;
}; };
// Note: this need to be a bit mask // Note: this need to be a bit mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment