"sims/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "8fec547470c1f629651b402d1780ca2c067faa6a"
Commit 199f7f71 authored by carlushuang

modify moe

parent 33ceea62
@@ -3,16 +3,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {

enum class FusedMoeWeightPermuteEnum
{
    // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
@@ -20,14 +12,4 @@ enum class FusedMoePermuteStyle
    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
    no_permute = 999,
};
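// A minimal host-side sketch (an illustration, not part of this commit) of the
// layout permute_b_nr_kr_kw_nw_kv implies: each expert's [N, K] weight is viewed
// with N = Nr*Nw and K = Kr*Kw*Kv, then stored as [Nr, Kr, Kw, Nw, Kv] so the
// per-wave chunk (Kw*Nw*Kv) is contiguous ("waveflatten"). Nw/Kw/Kv are assumed
// to mirror the mfma lane/vector factors used by GetMatrixCoreSwizzledBlockTile_0
// in the pipeline policy; the k split below is likewise an assumption.
template <typename T>
void host_preshuffle_b_nr_kr_kw_nw_kv(const T* src, T* dst,
                                      int E, int N, int K,
                                      int Nw, int Kw, int Kv)
{
    const int Nr = N / Nw;
    const int Kr = K / (Kw * Kv);
    for(int e = 0; e < E; e++)
        for(int nr = 0; nr < Nr; nr++)
            for(int kr = 0; kr < Kr; kr++)
                for(int kw = 0; kw < Kw; kw++)
                    for(int nw = 0; nw < Nw; nw++)
                        for(int kv = 0; kv < Kv; kv++)
                        {
                            const int n = nr * Nw + nw;             // original row
                            const int k = (kr * Kw + kw) * Kv + kv; // original col (assumed split)
                            const long o =
                                ((((static_cast<long>(e) * Nr + nr) * Kr + kr) * Kw + kw) * Nw +
                                 nw) * Kv + kv;
                            dst[o] = src[(static_cast<long>(e) * N + n) * K + k];
                        }
}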
template <bool DownPreShuffled_ = false,
FusedMoePermuteStyle PermuteStyle_ = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits
{
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr FusedMoePermuteStyle PermuteStyle = PermuteStyle_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile
@@ -14,7 +14,7 @@ namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
struct FusedMoePipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;
@@ -27,43 +27,49 @@ struct BlockFmhaPipelineQRKSVSAsync
    using AccDataType       = remove_cvref_t<typename Problem::AccDataType>;
    using ScaleDataType     = remove_cvref_t<typename Problem::ScaleDataType>;
    using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kBlockM_0      = FusedMoeTileShape::kBlockM_0;
    static constexpr index_t kBlockN_0      = FusedMoeTileShape::kBlockN_0;
    static constexpr index_t kBlockK_0      = FusedMoeTileShape::kBlockK_0;
    static constexpr index_t kWarpM_0       = FusedMoeTileShape::kWarpM_0;
    static constexpr index_t kWarpN_0       = FusedMoeTileShape::kWarpN_0;
    static constexpr index_t kWarpK_0       = FusedMoeTileShape::kWarpK_0;
    static constexpr index_t kBlockWarpsM_0 = FusedMoeTileShape::kBlockWarpsM_0;
    static constexpr index_t kBlockWarpsN_0 = FusedMoeTileShape::kBlockWarpsN_0;
    static constexpr index_t kBlockWarpsK_0 = FusedMoeTileShape::kBlockWarpsK_0;
    static constexpr index_t kSubBlockM_0   = FusedMoeTileShape::kSubBlockM_0;
    static constexpr index_t kSubBlockN_0   = FusedMoeTileShape::kSubBlockN_0;
    static constexpr index_t kSubBlockK_0   = FusedMoeTileShape::kSubBlockK_0;
    static constexpr index_t kWarpRepeatM_0 = FusedMoeTileShape::kWarpRepeatM_0;
    static constexpr index_t kWarpRepeatN_0 = FusedMoeTileShape::kWarpRepeatN_0;
    static constexpr index_t kWarpRepeatK_0 = FusedMoeTileShape::kWarpRepeatK_0;

    static constexpr index_t kBlockM_1      = FusedMoeTileShape::kBlockM_1;
    static constexpr index_t kBlockN_1      = FusedMoeTileShape::kBlockN_1;
    static constexpr index_t kBlockK_1      = FusedMoeTileShape::kBlockK_1;
    static constexpr index_t kWarpM_1       = FusedMoeTileShape::kWarpM_1;
    static constexpr index_t kWarpN_1       = FusedMoeTileShape::kWarpN_1;
    static constexpr index_t kWarpK_1       = FusedMoeTileShape::kWarpK_1;
    static constexpr index_t kBlockWarpsM_1 = FusedMoeTileShape::kBlockWarpsM_1;
    static constexpr index_t kBlockWarpsN_1 = FusedMoeTileShape::kBlockWarpsN_1;
    static constexpr index_t kBlockWarpsK_1 = FusedMoeTileShape::kBlockWarpsK_1;
    static constexpr index_t kSubBlockM_1   = FusedMoeTileShape::kSubBlockM_1;
    static constexpr index_t kSubBlockN_1   = FusedMoeTileShape::kSubBlockN_1;
    static constexpr index_t kSubBlockK_1   = FusedMoeTileShape::kSubBlockK_1;
    static constexpr index_t kWarpRepeatM_1 = FusedMoeTileShape::kWarpRepeatM_1;
    static constexpr index_t kWarpRepeatN_1 = FusedMoeTileShape::kWarpRepeatN_1;
    static constexpr index_t kWarpRepeatK_1 = FusedMoeTileShape::kWarpRepeatK_1;

    using MBlockType   = decltype(Policy::template GetMatrixCoreSwizzledBlockTile_0<Problem>());
    using MBlockType_1 = decltype(Policy::template GetMatrixCoreSwizzledBlockTile_1<Problem>());

    static constexpr index_t kBlockNr_0        = MBlockType{}.at(number<0>{});
    static constexpr index_t kBlockKr_0        = MBlockType{}.at(number<1>{});
    static constexpr index_t kBlockWaveFlatten = MBlockType{}.at(number<2>{});
    // gemm-1 counterparts; the kernel below reads FusedMoePipeline::kBlockNr_1/kBlockKr_1
    static constexpr index_t kBlockNr_1 = MBlockType_1{}.at(number<0>{});
    static constexpr index_t kBlockKr_1 = MBlockType_1{}.at(number<1>{});
    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
@@ -71,37 +77,7 @@ struct BlockFmhaPipelineQRKSVSAsync
        else
        {
            // minimize occupancy
            return 2;
        }
    }();
@@ -179,23 +155,261 @@ struct BlockFmhaPipelineQRKSVSAsync
            o_gtile_window_tmp.get_window_lengths(),
            o_gtile_window_tmp.get_window_origin(),
            Policy::template MakeOGlobalTileDistribution<Problem>());
using g_thread_type = decltype(load_tile(g_gtile_window));
using u_thread_type = decltype(load_tile(u_gtile_window));
using d_thread_type = decltype(load_tile(d_gtile_window));
const index_t loops_0 = (dim_size + kBlockK_0 - 1) / kBlockK_0;
const index_t loops_1 = (dim_size + kBlockN_1 - 1) / kBlockN_1;
// auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
// issues_warps_lanes
auto a_sst_0 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// issues_warps_lanes
auto a_sst_1 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// m*k
auto a_sld_0 = make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0});
// m*k
auto a_sld_1 = make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0});
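        // Note (a sketch of the intent): the same LDS bytes get two descriptors.
        // The store side (a_sst_*) is 3D issues x warps x lanes, matching how
        // async_load_tile issues wave-wide copies straight into LDS slots; the
        // load side (a_sld_*) views the identical buffer as a plain m x k tile
        // for the warp-gemm reads. smem_0/smem_1 form a double buffer: gemm
        // reads one while the next K-chunk of A lands in the other.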
        g_thread_type g_tile[2];
        u_thread_type u_tile[2];
        using WarpGemm0 = remove_cvref_t<decltype(Policy::template GetWarpGemm0<Problem>())>;
        using WarpGemm1 = remove_cvref_t<decltype(Policy::template GetWarpGemm1<Problem>())>;
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
        // TODO: N first, M next
const index_t i_mwarp_0 = get_warp_id() / kBlockWarpsN_0;
// create and pre-cache a warp-window
auto make_a_warp_windows = [&](auto a_sld_) {
// construct A-warp-window
auto warp_window = make_tile_window(
a_sld_.get_bottom_tensor_view(),
make_tuple(number<WarpGemm0::kM>{}, number<WarpGemm0::kK>{}),
a_sld_.get_window_origin() + multi_index<2>{i_mwarp_0 * WarpGemm0::kM, 0},
make_static_tile_distribution(typename WarpGemm0::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(warp_window), kWarpRepeatK_0>,
kWarpRepeatM_0>
ws;
// pre-cache the warp windows
static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m_iter) {
static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k_iter) {
ws(i_m_iter)(i_k_iter) = warp_window;
                    // assumption: step by the per-repeat sub-block sizes in m/k
                    move_tile_window(ws(i_m_iter)(i_k_iter),
                                     {i_m_iter * kSubBlockM_0, i_k_iter * kSubBlockK_0});
});
});
return ws;
};
auto a_warp_windows_0 = make_a_warp_windows(a_sld_0);
auto a_warp_windows_1 = make_a_warp_windows(a_sld_1);
constexpr auto true_v = bool_constant<true>{};
constexpr auto false_v = bool_constant<false>{};
auto do_load_a0 = [&](auto& a_store_, auto move_) {
async_load_tile(a_store_, a_gtile_window);
if constexpr(move_)
move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}});
};
auto do_load_b0 = [&](auto& g_tile_, auto& u_tile_, auto move_) {
g_tile_ = load_tile(g_gtile_window);
u_tile_ = load_tile(u_gtile_window);
if constexpr(move_)
{
move_tile_window(g_gtile_window, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
move_tile_window(u_gtile_window, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
}
};
auto do_load_b1 = [&](auto& d_tile_, auto move_) {
d_tile_ = load_tile(d_gtile_window);
if constexpr(move_)
{
            // d advances along gemm-1's Kr dimension
            move_tile_window(d_gtile_window, {number<0>{}, number<kBlockKr_1>{}, number<0>{}});
}
};
// using AWarpTensor = typename decltype(warp_gemm_0)::AWarpTensor{};
// using CWarpTensor =
        auto acc_g = Policy::template MakeCBlockTile_Gemm0<Problem>();
        auto acc_u = Policy::template MakeCBlockTile_Gemm0<Problem>();
// async_load_tile(a_sst_0, a_gtile_window); move_tile_window(a_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); g_tile[0] = load_tile(g_gtile_window);
// move_tile_window(g_gtile_window, {number<0>{}, number<kBlockK_0>{}}); u_tile[0] =
// load_tile(u_gtile_window); move_tile_window(u_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); async_load_tile(a_sst_1, a_gtile_window);
// move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}}); g_tile[1] =
// load_tile(g_gtile_window); move_tile_window(g_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); u_tile[1] = load_tile(u_gtile_window);
// move_tile_window(u_gtile_window, {number<0>{}, number<kBlockK_0>{}});
auto do_gemm_0 =
[&](auto& acc_g_, auto& acc_u_, auto& a_windows_, auto& g_tile_, auto& u_tile_) {
// as_br (asmem, breg)
static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k) {
static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m) {
const auto w_a = load_tile(a_windows_(i_m)(i_k));
static_for<0, kWarpRepeatN_0, 1>{}([&](auto i_n) {
constexpr auto beg_acc =
sequence<i_m * kSubBlockM_0, i_n * kSubBlockN_0>{};
constexpr auto end_acc =
sequence<(i_m + 1) * kSubBlockM_0, (i_n + 1) * kSubBlockN_0>{};
// 3d indexing for permuted g/u/d
constexpr auto beg_b =
sequence<i_m * kBlockWarpsM_0, i_n * kSubBlockN_0, 0>{};
constexpr auto end_b =
sequence<(i_m + 1) * kBlockWarpsM_0, (i_n + 1) * kSubBlockN_0, 0>{};
auto w_acc_g = get_slice_tile(acc_g_, beg_acc, end_acc);
auto w_acc_u = get_slice_tile(acc_u_, beg_acc, end_acc);
auto w_g = get_slice_tile(g_tile_, beg_b, end_b);
auto w_u = get_slice_tile(u_tile_, beg_b, end_b);
warp_gemm_0(w_acc_g, w_a, w_g);
warp_gemm_0(w_acc_u, w_a, w_u);
set_slice_tile(acc_g_, w_acc_g, beg_acc, end_acc);
set_slice_tile(acc_u_, w_acc_u, beg_acc, end_acc);
});
});
});
};
auto do_gemm_1 = [&](auto& acc_d_, auto& a_tile_, auto& d_tile_) {
// ar_br (areg, breg)
static_for<0, kWarpRepeatK_1, 1>{}([&](auto i_k) {
static_for<0, kWarpRepeatM_1, 1>{}([&](auto i_m) {
constexpr auto beg_a = sequence<i_m * kSubBlockM_1, i_k * kSubBlockK_1>{};
constexpr auto end_a =
sequence<(i_m + 1) * kSubBlockM_1, (i_k + 1) * kSubBlockK_1>{};
const auto w_a = get_slice_tile(a_tile_, beg_a, end_a);
static_for<0, kWarpRepeatN_1, 1>{}([&](auto i_n) {
                    constexpr auto beg_acc = sequence<i_m * kSubBlockM_1, i_n * kSubBlockN_1>{};
                    constexpr auto end_acc =
                        sequence<(i_m + 1) * kSubBlockM_1, (i_n + 1) * kSubBlockN_1>{};
                    // 3d indexing for permuted g/u/d
                    constexpr auto beg_b =
                        sequence<i_m * kBlockWarpsM_1, i_n * kSubBlockN_1, 0>{};
                    constexpr auto end_b =
                        sequence<(i_m + 1) * kBlockWarpsM_1, (i_n + 1) * kSubBlockN_1, 0>{};
auto w_acc_d = get_slice_tile(acc_d_, beg_acc, end_acc);
auto w_d = get_slice_tile(d_tile_, beg_b, end_b);
warp_gemm_1(w_acc_d, w_a, w_d);
set_slice_tile(acc_d_, w_acc_d, beg_acc, end_acc);
});
});
});
};
// start of pipeline
do_load_a0(a_sst_0, true_v);
do_load_b0(g_tile[0], u_tile[0], true_v);
do_load_a0(a_sst_1, true_v);
do_load_b0(g_tile[1], u_tile[1], true_v);
clear_tile(acc_g);
clear_tile(acc_u);
        index_t i_0 = 0;
        while(i_0 < (loops_0 - 2))
{
// first buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_0, g_tile[0], u_tile[0]);
do_load_a0(a_sst_0, true_v);
do_load_b0(g_tile[0], u_tile[0], true_v);
i_0++;
// second buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_1, g_tile[1], u_tile[1]);
do_load_a0(a_sst_1, true_v);
do_load_b0(g_tile[1], u_tile[1], true_v);
i_0++;
}
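        // Steady-state picture of the loop above (a sketch): the prologue issued
        // chunks 0/1 into the two buffers, so each iteration overlaps
        //   gemm_0 on buffer b  <->  async-load of chunk i+2 into buffer b
        // and the tail below drains the last two buffers without new loads.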
// first buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_0, g_tile[0], u_tile[0]);
// prefetch
d_thread_type d_tile[2];
do_load_b1(d_tile[0], true_v);
do_load_b1(d_tile[1], true_v);
// second buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_1, g_tile[1], u_tile[1]);
        // combine acc_g/u: acc_g = silu(acc_g) * acc_u
constexpr auto acc_spans_0 = decltype(acc_g)::get_distributed_spans();
sweep_tile_span(acc_spans_0[number<0>{}], [&](auto idx0) {
sweep_tile_span(acc_spans_0[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
element_wise::Silu{}(acc_g(i_j_idx), acc_g(i_j_idx));
acc_g(i_j_idx) *= acc_u(i_j_idx);
});
});
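        // Scalar reference of the gating above (assuming Silu(x) = x * sigmoid(x)):
        //   acc_g[i,j] = silu(acc_g[i,j]) * acc_u[i,j]
        //              = (acc_g[i,j] / (1 + expf(-acc_g[i,j]))) * acc_u[i,j]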
        const auto y = [&]() {
            if constexpr(std::is_same_v<YDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<YDataType>(acc_g);
else
return cast_tile<YDataType>(acc_g);
}();
        auto acc_d = Policy::template MakeCBlockTile_Gemm1<Problem>();
clear_tile(acc_d);
        // TODO: reshuffle? 32x32x8 mfma can avoid LDS reshuffle
        index_t i_1 = 0;
while(i_1 < (loops_1 - 2))
{
// first buffer
do_gemm_1(acc_d, y, d_tile[0]);
do_load_b1(d_tile[0], true_v);
i_1++;
// second buffer
do_gemm_1(acc_d, y, d_tile[1]);
do_load_b1(d_tile[1], true_v);
i_1++;
}
        // first buffer
        do_gemm_1(acc_d, y, d_tile[0]);
        i_1++;

        // second buffer
        do_gemm_1(acc_d, y, d_tile[1]);
        i_1++;
    }
    template <typename QDramBlockWindowTmp,
...
@@ -117,14 +117,15 @@ struct FusedMoePipelinePolicy
    template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
    {
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec;

        if constexpr(get_warp_size() <= K_rem)
        {
            static_assert(K_rem % get_warp_size() == 0);
            constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
            constexpr index_t K_wav = K_rem / get_warp_size();
            static_assert(K_wav <= NumWarps, "do not support thread repeat along K yet");
            constexpr index_t M_wav = NumWarps / K_wav;
            static_assert(MPerBlock % M_wav == 0, "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / M_wav;
@@ -150,14 +151,56 @@ struct FusedMoePipelinePolicy
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,
                    tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
                    tuple<sequence<1>, sequence<1, 2>>,
                    tuple<sequence<1>, sequence<2, 0>>,
                    sequence<1, 2>,
                    sequence<0, 1>>{});
        }
    }
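    // Worked example with hypothetical numbers: Alignment = 8, KPerBlock = 512,
    // warp_size = 64, NumWarps = 4 -> K_vec = 8, K_rem = 64, so one wave exactly
    // covers K (K_lan = 64, K_wav = 1) and all four waves stack along M
    // (M_wav = 4, M_rep = MPerBlock / 4).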
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTile_0()
    {
        if constexpr(Problem::Traits::PermuteStyle ==
                     FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            // assume warpgemm0/1 are the same
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
            constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_0;
            constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_0;

            constexpr index_t Kv = GetAlignment_G<Problem>();
            constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            static_assert(KPerBlock % (Kv * Kw) == 0);
            constexpr index_t Nr = NPerBlock / Nw;
            constexpr index_t Kr = KPerBlock / (Kv * Kw);

            return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
        }
    }
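    // Worked example with hypothetical mfma numbers (Nw = kAMLane = 32,
    // Kw = kABKLane = 2, Kv = alignment = 8): kBlockN_0 = 128, kBlockK_0 = 64
    // gives Nr = 4, Kr = 64 / (8 * 2) = 4 and a waveflatten chunk of
    // Kw * Nw * Kv = 512 elements, i.e. a block tile of sequence<4, 4, 512>.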
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTile_1()
    {
        if constexpr(Problem::Traits::PermuteStyle ==
                     FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            // assume warpgemm0/1 are the same
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
            constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_1;
            constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_1;

            constexpr index_t Kv = GetAlignment_G<Problem>();
            constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            static_assert(KPerBlock % (Kv * Kw) == 0);
            constexpr index_t Nr = NPerBlock / Nw;
            constexpr index_t Kr = KPerBlock / (Kv * Kw);

            return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
        }
    }
    // Caution: this will require global memory pre-shuffled to follow the mfma layout
    // to maximize the L1/L2 channels while skipping LDS
    template <index_t NPerBlock,
@@ -166,12 +209,13 @@ struct FusedMoePipelinePolicy
              index_t WavesPerBlock_K,
              typename WarpGemm,
              index_t Alignment,
              FusedMoeWeightPermuteEnum PermuteStyle =
                  FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
    {
        static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);

        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
            constexpr index_t Kv = Alignment;
@@ -218,20 +262,18 @@ struct FusedMoePipelinePolicy
                                                                  Alignment>();
    }
    template <typename Problem, index_t NSplits = 2>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G(number<NSplits> = {})
    {
        constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_0;
            constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_0;
            constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
            constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
            constexpr index_t Alignment = GetAlignment_G<Problem>();
            return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                                  kKPerBlock,
                                                                  WavesPerBlock_N,
@@ -242,20 +284,18 @@ struct FusedMoePipelinePolicy
        }
    }
    template <typename Problem, index_t NSplits = 2>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_U(number<NSplits> = {})
    {
        constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_0;
            constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_0;
            constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
            constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
            constexpr index_t Alignment = GetAlignment_U<Problem>();
            return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                                  kKPerBlock,
                                                                  WavesPerBlock_N,
@@ -270,16 +310,14 @@ struct FusedMoePipelinePolicy
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
    {
        constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            constexpr index_t kNPerBlock      = Problem::FusedMoeTileShape::kBlockN_1;
            constexpr index_t kKPerBlock      = Problem::FusedMoeTileShape::kBlockK_1;
            constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_1;
            constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_1;
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
            constexpr index_t Alignment = GetAlignment_D<Problem>();
            return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                                  kKPerBlock,
                                                                  WavesPerBlock_N,
@@ -290,65 +328,12 @@ struct FusedMoePipelinePolicy
        }
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
    {
        // A async->LDS
        constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
        constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t warpSize   = ck_tile::get_warp_size();
        constexpr index_t NumWarps   = Problem::FusedMoeTileShape::NumWarps;
@@ -359,7 +344,7 @@ struct FusedMoePipelinePolicy
        static_assert(kKPerBlock % kVector == 0);
        constexpr index_t LanesPerK = kKPerBlock / kVector; // how many threads load along K

        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
@@ -433,7 +418,7 @@ struct FusedMoePipelinePolicy
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
    {
        // A async->LDS
        // Note that, this descriptor is only to construct the layout inside LDS
@@ -442,8 +427,8 @@ struct FusedMoePipelinePolicy
        // below code is almost the same as SmemStore dist, with differences:
        // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
        // 2). returned descriptor is in NxK 2d layout
        constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
        constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t warpSize   = ck_tile::get_warp_size();
        constexpr index_t NumWarps   = Problem::FusedMoeTileShape::NumWarps;
@@ -454,12 +439,12 @@ struct FusedMoePipelinePolicy
        static_assert(kKPerBlock % kVector == 0);
        constexpr index_t LanesPerK = kKPerBlock / kVector; // how many threads load along K

        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
            constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK >= NumWarps)
            {
                // TODO: need multiple issues along K to load all data
            }
@@ -526,96 +511,6 @@ struct FusedMoePipelinePolicy
            return lds_desc_m_k;
        }
    }
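    // Worked example for the lane split above (hypothetical numbers):
    // kVector = 8 and kKPerBlock = 512 give LanesPerK = 64, so on a 64-lane
    // wave exactly one wave spans K (wavesPerK = 1) and the remaining warps
    // replicate along M; kKPerBlock = 1024 would need wavesPerK = 2 waves
    // cooperating on one K row.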
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
@@ -640,72 +535,57 @@ struct FusedMoePipelinePolicy
                                       Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<2>{}),
                                       true /*TransposeC*/>{};
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
    {
        using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;

        constexpr index_t BlockWarpsM = TileShape::kBlockWarpsM_0;
        constexpr index_t BlockWarpsN = TileShape::kBlockWarpsN_0;
        constexpr index_t WarpRepeatM = TileShape::kWarpRepeatM_0;
        constexpr index_t WarpRepeatN = TileShape::kWarpRepeatN_0;
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<WarpRepeatM, BlockWarpsM>, sequence<WarpRepeatN, BlockWarpsN>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor =
            make_static_distributed_tensor<typename Problem::AccDataType>(c_block_dstr);
        return c_block_tensor;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
    {
        using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;

        constexpr index_t BlockWarpsM = TileShape::kBlockWarpsM_1;
        constexpr index_t BlockWarpsN = TileShape::kBlockWarpsN_1;
        constexpr index_t WarpRepeatM = TileShape::kWarpRepeatM_1;
        constexpr index_t WarpRepeatN = TileShape::kWarpRepeatN_1;
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<WarpRepeatM, BlockWarpsM>, sequence<WarpRepeatN, BlockWarpsN>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor =
            make_static_distributed_tensor<typename Problem::AccDataType>(c_block_dstr);
        return c_block_tensor;
    }
};
} // namespace ck_tile
@@ -5,18 +5,18 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>

// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
//  (only for reference)    exp-0  exp-1     exp-2   exp-3            exp-4 exp-5
@@ -25,12 +25,9 @@
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
//  * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
//                        |-  exp-0  -|-  exp-1  -|-  exp-2  -|-       exp-3         -|-  exp-4  -|-  exp-5  -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
//
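// A host-side reference sketch of this sorting (an illustration, not the real
// moe_sorting kernel; assumes <vector> is available). Pads use pad_token_id
// (6 in the example above) with weight 0, and an expert without tokens still
// occupies one all-pad tile, matching the exp-4 segment shown.
inline void moe_sort_ref(const std::vector<std::vector<int>>& topk_ids,
                         const std::vector<std::vector<float>>& topk_weight,
                         int num_experts, int m_a, int pad_token_id,
                         std::vector<int>& sorted_token_ids,
                         std::vector<float>& sorted_weight,
                         std::vector<int>& sorted_expert_ids)
{
    for(int e = 0; e < num_experts; e++)
    {
        const std::size_t beg = sorted_token_ids.size();
        for(std::size_t t = 0; t < topk_ids.size(); t++)
            for(std::size_t k = 0; k < topk_ids[t].size(); k++)
                if(topk_ids[t][k] == e)
                {
                    sorted_token_ids.push_back(static_cast<int>(t));
                    sorted_weight.push_back(topk_weight[t][k]);
                }
        // pad each expert segment to a multiple of m_a (at least one tile)
        std::size_t len    = sorted_token_ids.size() - beg;
        std::size_t padded = ((len + m_a - 1) / m_a) * m_a;
        if(padded == 0)
            padded = m_a;
        for(; len < padded; len++)
        {
            sorted_token_ids.push_back(pad_token_id);
            sorted_weight.push_back(0.f);
        }
        for(std::size_t i = 0; i < padded / static_cast<std::size_t>(m_a); i++)
            sorted_expert_ids.push_back(e);
    }
}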
@@ -55,8 +52,7 @@
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
@@ -73,7 +69,7 @@
//    [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace ck_tile {
@@ -81,81 +77,45 @@ namespace ck_tile {
template <typename TilePartitioner_, typename FusedMoePipeline_, typename EpiloguePipeline_>
struct FusedMoeKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using FusedMoePipeline = remove_cvref_t<FusedMoePipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>; // TODO: not used
    static constexpr index_t kBlockSize  = FusedMoePipeline::kBlockSize;
    static constexpr index_t kBlockPerCu = FusedMoePipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr index_t kBlockPerCuInput = FusedMoePipeline::Problem::kBlockPerCu;

    using ADataType     = remove_cvref_t<typename FusedMoePipeline::ADataType>;
    using GDataType     = remove_cvref_t<typename FusedMoePipeline::GDataType>;
    using UDataType     = remove_cvref_t<typename FusedMoePipeline::UDataType>;
    using DDataType     = remove_cvref_t<typename FusedMoePipeline::DDataType>;
    using ODataType     = remove_cvref_t<typename FusedMoePipeline::ODataType>;
    using AccDataType   = remove_cvref_t<typename FusedMoePipeline::AccDataType>;
    using ScaleDataType = remove_cvref_t<typename FusedMoePipeline::ScaleDataType>;
    using FusedMoeTileShape = remove_cvref_t<typename FusedMoePipeline::FusedMoeTileShape>;

    static constexpr bool kPadDimSize    = FusedMoePipeline::kPadDimSize;
    static constexpr bool kPadHiddenSize = FusedMoePipeline::kPadHiddenSize;
    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>  { static constexpr const char * name = "fp32"; };
    template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<fp8_t>  { static constexpr const char * name = "fp8"; };
    template <> struct t2s<bf8_t>  { static constexpr const char * name = "bf8"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        // clang-format on
    }
    template <index_t I> // to avoid the duplicated base class problem, introduce a template arg
    struct FusedMoeEmptyKargs
    {
    };
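    // Why the template arg: several optional karg fragments may otherwise be the
    // same empty type, and C++ forbids inheriting the same base class twice.
    // Distinct instantiations keep the bases unique, e.g. (a sketch):
    //   struct SomeKargs : FusedMoeEmptyKargs<0>, FusedMoeEmptyKargs<1> { }; // ok
    //   // struct Bad    : Empty, Empty { };  // ill-formed: duplicate base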
@@ -180,25 +140,31 @@ struct FusedMoeKernel
        // const void* num_tokens_post_padded_ptr;
        const void* num_sorted_tiles_ptr;
        index_t dim_size;
        index_t hidden_size;
        index_t num_tokens;  // input number of tokens for current iteration
        index_t num_experts; // number of groups
        // index_t top_k;    // need this?

        index_t stride_a;
        index_t stride_gu; // assume g/u have the same stride
        // index_t stride_u;
        index_t stride_d;
        index_t stride_o;

        index_t stride_expert_gu; // assume g/u have the same stride
        index_t stride_expert_d;
    };
    struct FusedMoeMatrixCoreShuffleKargs : public FusedMoeCommonKargs
    {
        // weights are pre-shuffled as batch * nr_0 * kr_0 * waveflatten;
        // these are the nr-dim strides in that layout
        index_t stride_gu_nr;
        index_t stride_d_nr;
    };

    // TODO: switch karg based on
    using Kargs = FusedMoeMatrixCoreShuffleKargs;
    // host args are used inside host API
    // and should be POD data structure
@@ -217,21 +183,21 @@ struct FusedMoeKernel
        // const void* num_tokens_post_padded_ptr;
        const void* num_sorted_tiles_ptr;
        index_t dim_size;
        index_t hidden_size;
        index_t num_tokens;  // input number of tokens for current iteration
        index_t num_experts; // number of groups
        // index_t top_k;    // need this?

        index_t stride_a;
        index_t stride_g;
        index_t stride_u;
        index_t stride_d;
        index_t stride_o;

        index_t stride_expert_gu;
        index_t stride_expert_d;
    };
    using Hargs = FusedMoeCommonHargs;
@@ -244,45 +210,53 @@ struct FusedMoeKernel
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(FusedMoePipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        // __shared__ char smem_ptr[GetSmemSize()];
        index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
            *reinterpret_cast<const index_t*>(kargs.num_sorted_tiles_ptr));
index_t nr_0 = kargs.hidden_size / FusedMoePipeline::kBlockNr_0;
index_t kr_0 = kargs.dim_size / FusedMoePipeline::kBlockKr_0;
index_t nr_1 = kargs.dim_size / FusedMoePipeline::kBlockNr_1;
index_t kr_1 = kargs.hidden_size / FusedMoePipeline::kBlockKr_1;
__shared__ CK_TILE_LDS_ADDR ADataType smem_0[FusedMoePipeline::GetSmemSizeSingleBuffer()];
__shared__ CK_TILE_LDS_ADDR ADataType smem_1[FusedMoePipeline::GetSmemSizeSingleBuffer()];
        // persistent loop
        // while(true)
        {
            const auto [sorted_tile_id, hidden_tile_id] =
                TilePartitioner{}(num_sorted_tiles, kargs.hidden_size);
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            index_t expert_id = __builtin_amdgcn_readfirstlane(
                reinterpret_cast<const index_t*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
            // index along hidden_size
            index_t hidden_id =
                __builtin_amdgcn_readfirstlane(hidden_tile_id * FusedMoeTileShape::kBlockN_0);
            // same tile offset in pre-shuffled nr units (kBlockNr_0 per tile along N)
            index_t hidden_id_nr =
                __builtin_amdgcn_readfirstlane(hidden_tile_id * FusedMoePipeline::kBlockNr_0);
            const auto a_coord = FusedMoePipeline::GetAIndex(); // 2d thread offset, [i_row, i_col]
            const auto sorted_token_id =
                a_coord[number<0>{}] + sorted_tile_id * FusedMoeTileShape::kBlockM_0;

            index_t token_id =
                reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
            ScaleDataType scale =
                reinterpret_cast<const ScaleDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
            const auto a_gtile_window = [&]() {
                // A is already pre-padded in previous kernel
                const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
                const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
...@@ -299,116 +273,101 @@ struct FusedMoeKernel
                    make_tuple(sequence<0>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                const auto a_gtile_window_ =
                    make_tile_window(a_gather_view_,
                                     make_tuple(number<FusedMoeTileShape::kBlockM_0>{},
                                                number<FusedMoePipeline::kBlockK_0>{}),
                                     {0, 0});
                return a_gtile_window_;
            }();
// TODO: gtile using NSub to have less register pressure
            const auto g_gtile_window = [&]() {
                const GDataType* g_ptr =
                    reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                    static_cast<long_index_t>(expert_id) * kargs.stride_expert_gu +
                    hidden_id_nr * kargs.stride_gu_nr;
                const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                    g_ptr,
                    make_tuple(nr_0, kr_0, number<FusedMoePipeline::kBlockWaveFlatten>{}),
                    make_tuple(kargs.stride_gu_nr, number<FusedMoePipeline::kBlockWaveFlatten>{}, 1),
                    number<FusedMoePipeline::kAlignmentG>{},
                    number<1>{});
                const auto g_view_1_ =
                    pad_tensor_view(g_view_,
                                    make_tuple(number<FusedMoePipeline::kBlockNr_0>{},
                                               number<FusedMoePipeline::kBlockKr_0>{},
                                               number<FusedMoePipeline::kBlockWaveFlatten>{}),
                                    sequence<kPadHiddenSize, kPadDimSize, 0>{});
                const auto g_gtile_window_ =
                    make_tile_window(g_view_1_,
                                     make_tuple(number<FusedMoePipeline::kBlockNr_0>{},
                                                number<FusedMoePipeline::kBlockKr_0>{},
                                                number<FusedMoePipeline::kBlockWaveFlatten>{}),
                                     {0, 0, 0});
                return g_gtile_window_;
            }();
            const auto u_gtile_window = [&]() {
                const UDataType* u_ptr =
                    reinterpret_cast<const UDataType*>(kargs.u_ptr) +
                    static_cast<long_index_t>(expert_id) * kargs.stride_expert_gu +
                    hidden_id_nr * kargs.stride_gu_nr;
                const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
                    u_ptr,
                    make_tuple(nr_0, kr_0, number<FusedMoePipeline::kBlockWaveFlatten>{}),
                    make_tuple(kargs.stride_gu_nr, number<FusedMoePipeline::kBlockWaveFlatten>{}, 1),
                    number<FusedMoePipeline::kAlignmentU>{},
                    number<1>{});
                const auto u_view_1_ =
                    pad_tensor_view(u_view_,
                                    make_tuple(number<FusedMoePipeline::kBlockNr_0>{},
                                               number<FusedMoePipeline::kBlockKr_0>{},
                                               number<FusedMoePipeline::kBlockWaveFlatten>{}),
                                    sequence<kPadHiddenSize, kPadDimSize, 0>{});
                const auto u_gtile_window_ =
                    make_tile_window(u_view_1_,
                                     make_tuple(number<FusedMoePipeline::kBlockNr_0>{},
                                                number<FusedMoePipeline::kBlockKr_0>{},
                                                number<FusedMoePipeline::kBlockWaveFlatten>{}),
                                     {0, 0, 0});
                return u_gtile_window_;
            }();
            const auto d_gtile_window = [&]() {
                const DDataType* d_ptr =
                    reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                    static_cast<long_index_t>(expert_id) * kargs.stride_expert_d +
                    hidden_id_nr * kargs.stride_d_nr;
                const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                    d_ptr,
                    make_tuple(nr_1, kr_1, FusedMoePipeline::kBlockWaveFlatten),
                    make_tuple(kargs.stride_d_nr, FusedMoePipeline::kBlockWaveFlatten, 1),
                    number<FusedMoePipeline::kAlignmentD>{},
                    number<1>{});
                const auto d_view_1_ =
                    pad_tensor_view(d_view_,
                                    make_tuple(number<FusedMoePipeline::kBlockNr_1>{},
                                               number<FusedMoePipeline::kBlockKr_1>{},
                                               number<FusedMoePipeline::kBlockWaveFlatten>{}),
                                    sequence<kPadDimSize, kPadHiddenSize, 0>{});
                const auto d_gtile_window_ =
                    make_tile_window(d_view_1_,
                                     make_tuple(number<FusedMoePipeline::kBlockNr_1>{},
                                                number<FusedMoePipeline::kBlockKr_1>{},
                                                number<FusedMoePipeline::kBlockWaveFlatten>{}),
                                     {0, 0, 0});
                return d_gtile_window_;
            }();
            auto o_gtile_window = [&]() {
                const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr);
                const auto o_view_ = make_naive_tensor_view<address_space_enum::global,
                                                            memory_operation_enum::atomic_add>(
                    o_ptr,
                    make_tuple(kargs.num_tokens, kargs.dim_size),
                    make_tuple(kargs.stride_o, 1),
...@@ -423,10 +382,11 @@ struct FusedMoeKernel
                    make_tuple(sequence<0>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                const auto o_gtile_window_ =
                    make_tile_window(o_scatter_view_,
                                     make_tuple(number<FusedMoeTileShape::kBlockM_0>{},
                                                number<FusedMoePipeline::kBlockN_1>{}),
                                     {0, 0});
                return o_gtile_window_;
            }();
...@@ -436,9 +396,13 @@ struct FusedMoeKernel
                              u_gtile_window,
                              d_gtile_window,
                              o_gtile_window,
                              scale,
                              smem_0,
                              smem_1,
                              kargs.dim_size,
                              kargs.hidden_size);
            // tile_id += gridDim.x;
            // epilogue not used
        }
    }
};
...
...@@ -45,4 +45,31 @@ struct FusedMoeTilePartitioner_PersistentSplitD
    }
};
template <typename FusedMoeTileShape_>
struct FusedMoeTilePartitioner_Linear
{
using Shape = ck_tile::remove_cvref_t<FusedMoeTileShape_>;
static constexpr const char* name = "2d"; // expert x hidden
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
                                   ck_tile::index_t /*hidden_size*/)
{
index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y;
return ck_tile::make_tuple(i_m, i_n);
}
// persistent
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
{
// TODO: this may need tuning
        // index_t grids = num_cu * blocks_per_cu; // persistent-kernel sizing, not wired up here
index_t ms = ck_tile::integer_divide_ceil(max_tokens, Shape::kBlockM_0);
index_t ns = ck_tile::integer_divide_ceil(hidden_size, Shape::kBlockN_0);
return dim3(ns, ms, 1);
}
};
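// Host-side usage sketch (assumed pairing, for illustration only): the grid is sized
// from the worst-case token count, and the device operator() maps blockIdx back to a
// (sorted-tile, hidden-tile) pair:
//
//   using Partitioner = ck_tile::FusedMoeTilePartitioner_Linear<Shape>;
//   dim3 grids = Partitioner::GridSize(max_tokens, hidden_size);
//   // in-kernel: auto [i_m, i_n] = Partitioner{}(num_sorted_tiles, hidden_size);
//
// GridSize over-provisions along M, so padding tiles are rejected by the caller via
// the num_sorted_tiles check.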
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
/*
    This pipeline splits the gemm-N of the B matrix to reduce register pressure
    (assuming the B matrix is much larger than A)
*/
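/*
    Sketch of the N-split-2 schedule (restating the idea above): the G/U weight tile of
    one block is cut into two halves along gemm-N,

        N_0 = [ N_0/2 | N_0/2 ]  ->  g_win[0]/u_win[0] , g_win[1]/u_win[1]

    and each half is loaded and multiplied back-to-back against the same A tile, so only
    half of the B tile lives in registers at any time while A/LDS traffic is unchanged.
*/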
template <typename Problem_, typename Policy_ = FusedMoePipelineNSplit2Policy>
struct FusedMoePipelineNSplit2
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using GDataType = remove_cvref_t<typename Problem::GDataType>;
using UDataType = remove_cvref_t<typename Problem::UDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ScaleDataType = remove_cvref_t<typename Problem::ScaleDataType>;
    using YDataType = remove_cvref_t<typename Problem::YDataType>;
    using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kBlockNSub_0 = FusedMoeTileShape::kBlockNSub_0;
static constexpr index_t kBlockM_0 = FusedMoeTileShape::kBlockM_0;
static constexpr index_t kBlockN_0 = FusedMoeTileShape::kBlockN_0;
static constexpr index_t kBlockK_0 = FusedMoeTileShape::kBlockK_0;
static constexpr index_t kWarpM_0 = FusedMoeTileShape::kWarpM_0;
static constexpr index_t kWarpN_0 = FusedMoeTileShape::kWarpN_0;
static constexpr index_t kWarpK_0 = FusedMoeTileShape::kWarpK_0;
static constexpr index_t kBlockWarpsM_0 = FusedMoeTileShape::kBlockWarpsM_0;
static constexpr index_t kBlockWarpsN_0 = FusedMoeTileShape::kBlockWarpsN_0;
static constexpr index_t kBlockWarpsK_0 = FusedMoeTileShape::kBlockWarpsK_0;
static constexpr index_t kSubBlockM_0 = FusedMoeTileShape::kSubBlockM_0;
static constexpr index_t kSubBlockN_0 = FusedMoeTileShape::kSubBlockN_0;
static constexpr index_t kSubBlockK_0 = FusedMoeTileShape::kSubBlockK_0;
static constexpr index_t kWarpRepeatM_0 = FusedMoeTileShape::kWarpRepeatM_0;
static constexpr index_t kWarpRepeatN_0 = FusedMoeTileShape::kWarpRepeatN_0;
static constexpr index_t kWarpRepeatK_0 = FusedMoeTileShape::kWarpRepeatK_0;
    static_assert(kBlockN_0 == 2 * kBlockNSub_0); // this pipeline only supports split-2
static_assert(kWarpRepeatN_0 % 2 == 0);
static constexpr index_t kBlockM_1 = FusedMoeTileShape::kBlockM_1;
static constexpr index_t kBlockN_1 = FusedMoeTileShape::kBlockN_1;
static constexpr index_t kBlockK_1 = FusedMoeTileShape::kBlockK_1;
static constexpr index_t kWarpM_1 = FusedMoeTileShape::kWarpM_1;
static constexpr index_t kWarpN_1 = FusedMoeTileShape::kWarpN_1;
static constexpr index_t kWarpK_1 = FusedMoeTileShape::kWarpK_1;
static constexpr index_t kBlockWarpsM_1 = FusedMoeTileShape::kBlockWarpsM_1;
static constexpr index_t kBlockWarpsN_1 = FusedMoeTileShape::kBlockWarpsN_1;
static constexpr index_t kBlockWarpsK_1 = FusedMoeTileShape::kBlockWarpsK_1;
static constexpr index_t kSubBlockM_1 = FusedMoeTileShape::kSubBlockM_1;
static constexpr index_t kSubBlockN_1 = FusedMoeTileShape::kSubBlockN_1;
static constexpr index_t kSubBlockK_1 = FusedMoeTileShape::kSubBlockK_1;
static constexpr index_t kWarpRepeatM_1 = FusedMoeTileShape::kWarpRepeatM_1;
static constexpr index_t kWarpRepeatN_1 = FusedMoeTileShape::kWarpRepeatN_1;
static constexpr index_t kWarpRepeatK_1 = FusedMoeTileShape::kWarpRepeatK_1;
    using MBlockType_0 = decltype(Policy::template GetMatrixCoreSwizzledBlockTIle_0<Problem>());
    static constexpr index_t kBlockNr_0        = MBlockType_0::at(number<0>{});
    static constexpr index_t kBlockKr_0        = MBlockType_0::at(number<1>{});
    static constexpr index_t kBlockWaveFlatten = MBlockType_0::at(number<2>{});
    static_assert(kBlockNr_0 % 2 == 0);
    static constexpr index_t kBlockSubNr_0 = kBlockNr_0 / 2;
    using MBlockType_1 = decltype(Policy::template GetMatrixCoreSwizzledBlockTIle_1<Problem>());
    static constexpr index_t kBlockNr_1    = MBlockType_1::at(number<0>{});
    static constexpr index_t kBlockKr_1    = MBlockType_1::at(number<1>{});
    static constexpr index_t kBlockSubKr_1 = kBlockKr_1 / 2;
static_assert(kBlockSubNr_0 == kBlockSubKr_1);
static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>();
static constexpr index_t kAlignmentU = Policy::GetAlignment_U<Problem>();
static constexpr index_t kAlignmentD = Policy::GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::GetAlignment_O<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
return 2;
}
}();
static constexpr const char* name = "fused_moe_ns2";
    // using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>; // FMHA leftover, no dropout here
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSingleBuffer()
{
        return Policy::template GetSmemSizeSingleBuffer<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetAIndex()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOIndex()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
template <typename AGlobalTensorView,
typename GGlobalTileWindow,
typename UGlobalTileWindow,
typename DGlobalTileWindow,
typename OGlobalTensorView>
CK_TILE_DEVICE auto operator()(const AGlobalTensorView& a_gtile_window_tmp,
const GGlobalTileWindow& g_gtile_window_tmp,
const UGlobalTileWindow& u_gtile_window_tmp,
const DGlobalTileWindow& d_gtile_window_tmp,
OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr,
ScaleDataType scale,
CK_TILE_LDS_ADDR void* smem_0,
CK_TILE_LDS_ADDR void* smem_1,
index_t dim_size,
index_t /*hidden_size*/)
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
auto a_win = make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
a_gtile_window_tmp.get_window_lengths(),
a_gtile_window_tmp.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_win = generate_tuple(
[&](auto i) {
return make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kBlockSubNr_0>{},
number<kBlockKr_0>{},
number<kBlockWaveFlatten>{}),
{number<kBlockSubNr_0 * i>{}, I0, I0},
Policy::template MakeGlobalTileDistribution_G<Problem>());
},
number<2>{});
auto u_win = generate_tuple(
[&](auto i) {
return make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kBlockSubNr_0>{},
number<kBlockKr_0>{},
number<kBlockWaveFlatten>{}),
{number<kBlockSubNr_0 * i>{}, I0, I0},
Policy::template MakeGlobalTileDistribution_U<Problem>());
},
number<2>{});
        auto d_win = generate_tuple(
            [&](auto i) {
                // note: reuses the U tile distribution here; the static_assert
                // kBlockSubNr_0 == kBlockSubKr_1 above is what makes the shapes line up
                return make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
                                        make_tuple(number<kBlockNr_1>{},
                                                   number<kBlockSubKr_1>{},
                                                   number<kBlockWaveFlatten>{}),
                                        {I0, number<kBlockSubKr_1 * i>{}, I0},
                                        Policy::template MakeGlobalTileDistribution_U<Problem>());
            },
            number<2>{});
auto o_win = make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
o_gtile_window_tmp.get_window_lengths(),
o_gtile_window_tmp.get_window_origin(),
Policy::template MakeOGlobalTileDistribution<Problem>());
using g_thread_type = decltype(load_tile(g_win[I0]));
using u_thread_type = decltype(load_tile(u_win[I0]));
using d_thread_type = decltype(load_tile(d_win[I0]));
const index_t loops_0 = (dim_size + kBlockK_0 - 1) / kBlockK_0;
const index_t loops_1 = (dim_size + kBlockN_1 - 1) / kBlockN_1;
// issues_warps_lanes
        auto a_st0 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem_0),
                Policy::template MakeLdsStoreDesc_A<Problem>()),
            Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
            {0, 0, 0});
        // issues_warps_lanes
        auto a_st1 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem_1),
                Policy::template MakeLdsStoreDesc_A<Problem>()),
            Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
            {0, 0, 0});
        // m*k
        auto a_ld0 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem_0),
                Policy::template MakeLdsLoadDesc_A<Problem>()),
            Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
            {0, 0});
        // m*k
        auto a_ld1 = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem_1),
                Policy::template MakeLdsLoadDesc_A<Problem>()),
            Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
            {0, 0});
statically_indexed_array<g_thread_type, 2> g_tls;
statically_indexed_array<u_thread_type, 2> u_tls;
        using WarpGemm0 = remove_cvref_t<decltype(Policy::template GetWarpGemm0<Problem>())>;
        using WarpGemm1 = remove_cvref_t<decltype(Policy::template GetWarpGemm1<Problem>())>;
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
        // TODO: N first, M next
// create and pre-cache a_reg warp-window
auto make_a_warp_windows = [&](auto a_sld_) {
const index_t i_mwarp_0 = get_warp_id() / kBlockWarpsN_0;
// construct A-warp-window
auto warp_window = make_tile_window(
a_sld_.get_bottom_tensor_view(),
make_tuple(number<WarpGemm0::kM>{}, number<WarpGemm0::kK>{}),
a_sld_.get_window_origin() + multi_index<2>{i_mwarp_0 * WarpGemm0::kM, 0},
make_static_tile_distribution(typename WarpGemm0::AWarpDstrEncoding{}));
return warp_window;
};
auto a_warp_windows_0 = make_a_warp_windows(a_ld0);
auto a_warp_windows_1 = make_a_warp_windows(a_ld1);
auto load_a = [&](auto& a_store_) {
async_load_tile(a_store_, a_win);
move_tile_window(a_win, {number<0>{}, number<kBlockK_0>{}});
};
auto load_n = [&](auto& g_tile_, auto& u_tile_, auto& g_window_, auto& u_window_) {
g_tile_ = load_tile(g_window_);
u_tile_ = load_tile(u_window_);
move_tile_window(g_window_, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
move_tile_window(u_window_, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
};
        auto load_d = [&](auto& d_tile_, auto& d_window_) {
            d_tile_ = load_tile(d_window_);
            // advance along nr (the gemm-1 N / output dim_size direction)
            move_tile_window(d_window_, {number<kBlockNr_1>{}, number<0>{}, number<0>{}});
        };
        auto acc_g = generate_tuple(
            [&](auto) { return Policy{}.template MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
        auto acc_u = generate_tuple(
            [&](auto) { return Policy{}.template MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
        // Note: this function only does the gemm of a single Nsplit
// clang-format off
auto gemm_0 = [&](auto& acc_g_, auto& acc_u_, auto& a_, auto& g_, auto& u_) {
static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k) {
static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m) {
constexpr auto beg_a = sequence<i_m * kSubBlockM_0, i_k * kSubBlockK_0 >{};
constexpr auto end_a = sequence<(i_m+1) * kSubBlockM_0, (i_k+1) * kSubBlockK_0 >{};
auto w_a = get_slice_tile(a_, beg_a, end_a);
static_for<0, kWarpRepeatN_0 / 2, 1>{}([&](auto i_n) {
constexpr auto beg_acc = sequence<i_m * kSubBlockM_0, i_n * kSubBlockN_0>{};
constexpr auto end_acc = sequence<(i_m + 1) * kSubBlockM_0, (i_n + 1) * kSubBlockN_0>{};
constexpr auto beg_b = sequence<i_n * kSubBlockN_0, i_k * kSubBlockK_0, 0>{};
constexpr auto end_b = sequence<(i_n + 1) * kSubBlockN_0, (i_k + 1) * kSubBlockK_0, 0>{};
auto w_acc_g = get_slice_tile(acc_g_, beg_acc, end_acc);
auto w_acc_u = get_slice_tile(acc_u_, beg_acc, end_acc);
auto w_g = get_slice_tile(g_, beg_b, end_b);
auto w_u = get_slice_tile(u_, beg_b, end_b);
warp_gemm_0(w_acc_g, w_a, w_g);
warp_gemm_0(w_acc_u, w_a, w_u);
set_slice_tile(acc_g_, w_acc_g, beg_acc, end_acc);
set_slice_tile(acc_u_, w_acc_u, beg_acc, end_acc);
});
});
});
};
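        // Reading the loop nest above: for each (i_k, i_m) warp step, the same A slice
        // w_a is reused across kWarpRepeatN_0 / 2 N-steps (one Nsplit worth), and G and
        // U share that slice too - one A register load feeds two mfma streams, which is
        // the point of fusing the gate and up projections.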
// clang-format on
// clang-format off
auto gemm_1 = [&](auto& acc_d_, auto& y_, auto& d_) {
static_for<0, kWarpRepeatK_1, 1>{}([&](auto i_k) {
static_for<0, kWarpRepeatM_1, 1>{}([&](auto i_m) {
constexpr auto beg_a = sequence<i_m * kSubBlockM_1, i_k * kSubBlockK_1>{};
constexpr auto end_a = sequence<(i_m + 1) * kSubBlockM_1, (i_k + 1) * kSubBlockK_1>{};
const auto w_y = get_slice_tile(y_, beg_a, end_a);
static_for<0, kWarpRepeatN_1, 1>{}([&](auto i_n) {
constexpr auto beg_acc = sequence<i_m * kSubBlockM_1, i_n * kSubBlockN_1>{};
constexpr auto end_acc = sequence<(i_m + 1) * kSubBlockM_1, (i_n + 1) * kSubBlockN_1>{};
constexpr auto beg_d = sequence<i_n * kSubBlockN_1, i_k * kSubBlockK_1, 0>{};
constexpr auto end_d = sequence<(i_n + 1) * kSubBlockN_1, (i_k + 1) * kSubBlockK_1, 0>{};
auto w_acc_d = get_slice_tile(acc_d_, beg_acc, end_acc);
auto w_d = get_slice_tile(d_, beg_d, end_d);
warp_gemm_1(w_acc_d, w_y, w_d);
set_slice_tile(acc_d_, w_acc_d, beg_acc, end_acc);
});
});
});
};
// clang-format on
constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win[I0].get_num_of_access()>{};
constexpr auto issues_u = number<u_win[I0].get_num_of_access()>{};
constexpr auto issues_b = issues_g + issues_u;
constexpr auto issues_d = number<d_win[I0].get_num_of_access()>{};
constexpr auto issues_o = number<o_win.get_num_of_access()>{};
// start of pipeline
// clang-format off
load_a(a_st0);
load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
load_a(a_st1);
clear_tile(acc_g[I0]); clear_tile(acc_g[I1]); clear_tile(acc_u[I0]); clear_tile(acc_u[I1]);
auto a_reg = decltype(load_tile(a_warp_windows_0)){};
index_t i_0 = 0;
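        // Main gemm-0 loop: classic ping-pong over smem_0/smem_1. The
        // buffer_load_fence(...) counts below encode how many outstanding async loads
        // (A plus the two G/U halves) may still be in flight before the next consumer
        // proceeds, and wave_barrier() keeps the waves of a workgroup in lockstep
        // around the shared LDS buffers.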
while(i_0 < (loops_0 - 2))
{
// first buffer
buffer_load_fence(issues_b + issues_b + issues_a);
wave_barrier(); a_reg = load_tile(a_warp_windows_0);
buffer_load_fence(issues_b + issues_a);
gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
buffer_load_fence(issues_b + issues_a);
gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
load_a(a_st0);
i_0++;
// second buffer
buffer_load_fence(issues_b + issues_b + issues_a);
wave_barrier(); a_reg = load_tile(a_warp_windows_1);
buffer_load_fence(issues_b + issues_a);
gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
buffer_load_fence(issues_b + issues_a);
gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
load_a(a_st1);
i_0++;
}
// first buffer
buffer_load_fence(issues_b + issues_b + issues_a);
wave_barrier(); a_reg = load_tile(a_warp_windows_0);
gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
load_n(g_tls[I0], u_tls[I0], g_win[I0], u_win[I0]);
buffer_load_fence(issues_b + issues_a);
gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
load_n(g_tls[I1], u_tls[I1], g_win[I1], u_win[I1]);
// second buffer
buffer_load_fence(issues_b + issues_b);
wave_barrier(); a_reg = load_tile(a_warp_windows_1);
buffer_load_fence(issues_b);
gemm_0(acc_g[I0], acc_u[I0], a_reg, g_tls[I0], u_tls[I0]);
// prefetch
statically_indexed_array<d_thread_type, 2> d_tls;
        load_d(d_tls[I0], d_win[I0]); load_d(d_tls[I1], d_win[I1]);
buffer_load_fence(issues_d + issues_d);
gemm_0(acc_g[I1], acc_u[I1], a_reg, g_tls[I1], u_tls[I1]);
        // activation on acc_g/u: y = silu(gate) * up, per element
        constexpr auto acc_spans_0 = remove_cvref_t<decltype(acc_g[I0])>::get_distributed_spans();
sweep_tile_span(acc_spans_0[number<0>{}], [&](auto idx0) {
sweep_tile_span(acc_spans_0[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
element_wise::Silu{}(acc_g[I0](i_j_idx), acc_g[I0](i_j_idx));
element_wise::Silu{}(acc_g[I1](i_j_idx), acc_g[I1](i_j_idx));
acc_g[I0](i_j_idx) *= acc_u[I0](i_j_idx);
acc_g[I1](i_j_idx) *= acc_u[I1](i_j_idx);
});
});
const auto y_reg = generate_tuple([&](auto i) {
if constexpr(std::is_same_v<YDataType, fp16_t>) return impl::cast_tile_pk_fp16_fp32<YDataType>(acc_g[i]);
else return cast_tile<YDataType>(acc_g[i]); }, number<2>{});
        auto acc_d = Policy{}.template MakeCBlockTile_Gemm1<Problem>();
        // TODO: reshuffle? 32x32x8 mfma can avoid LDS reshuffle
// Second gemm
clear_tile(acc_d);
// first buffer
buffer_load_fence(issues_d);
        gemm_1(acc_d, y_reg[I0], d_tls[I0]); load_d(d_tls[I0], d_win[I0]);
// second buffer
buffer_load_fence(issues_d);
        gemm_1(acc_d, y_reg[I1], d_tls[I1]); load_d(d_tls[I1], d_win[I1]);
        update_tile(o_win, acc_d);
        move_tile_window(o_win, {number<0>{}, number<kBlockN_1>{}}); // step O along dim (gemm-1 N)
index_t i_1 = 0;
while(i_1 < (loops_1 - 2))
{
clear_tile(acc_d);
// first buffer
buffer_load_fence(issues_d + issues_o);
            gemm_1(acc_d, y_reg[I0], d_tls[I0]); load_d(d_tls[I0], d_win[I0]);
// second buffer
buffer_load_fence(issues_d + issues_o);
            gemm_1(acc_d, y_reg[I1], d_tls[I1]); load_d(d_tls[I1], d_win[I1]);
            update_tile(o_win, acc_d);
            move_tile_window(o_win, {number<0>{}, number<kBlockN_1>{}}); // step O along dim
i_1++;
}
clear_tile(acc_d);
// first buffer
buffer_load_fence(issues_d + issues_o);
gemm_1(acc_d, y_reg[I0], d_tls[I0]);
// second buffer
buffer_load_fence(issues_o);
gemm_1(acc_d, y_reg[I1], d_tls[I1]);
update_tile(o_win, acc_d);
// clang-format on
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
struct FusedMoePipelineNSplit2Policy
{
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO:
return 1;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
static constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
static constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_U()
{
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::UDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
static constexpr index_t copy_bytes = [&]() { return 16; }();
static constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
{
if constexpr(Problem::Traits::OAtomic == 0)
{
// pack fp16/bf16 atomic
static_assert(sizeof(typename Problem::ODataType) == 2);
return 2;
}
else if constexpr(Problem::Traits::OAtomic == 1)
{
// fp32 atomic
return 1;
}
else
{
return 16 / sizeof(typename Problem::ODataType);
}
}
    template <typename DataType_>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
    {
        // TODO: this is for 3d layout
        return 16 / sizeof(remove_cvref_t<DataType_>);
    }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
return GetSmemKPack<typename Problem::ADataType>();
}
#if 0
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWaveFlattenShape()
    {
        using WarpGemm = decltype(GetWarpGemm0<Problem>()); // assume warpgemm0/1 are the same
        constexpr index_t Kv = GetAlignment_G<Problem>();
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
        return sequence<Kw, Nw, Kv>{};
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockTileNrKr()
    {
        using WarpGemm = decltype(GetWarpGemm0<Problem>()); // assume warpgemm0/1 are the same
        constexpr index_t Kv = GetAlignment_G<Problem>();
        constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
        constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
        return sequence<Problem::FusedMoeTileShape::kBlockN_0 / Nw,
                        Problem::FusedMoeTileShape::kBlockK_0 / (Kw * Kv)>{};
    }
#endif
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSingleBuffer()
{
        constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
        constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
return a_sld_desc.get_element_space_size();
}
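    // Note: the pipeline allocates two of these buffers (smem_0/smem_1) for the
    // ping-pong schedule, so the kernel-visible LDS footprint is
    // 2 * GetSmemSizeSingleBuffer() elements of ADataType.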
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// optimized version for async
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() <= K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
// NOTE: no swap, but hard to avoid LDS bank conflict
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
// NOTE: swapped for LDS load bank conflict free
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
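    // Worked example (hypothetical tile, for intuition): MPerBlock = 32, KPerBlock = 256,
    // NumWarps = 4, Alignment = 8 gives K_vec = 8 and K_rem = 32 < warp_size (64), so the
    // second branch is taken: K_lan = 32 lanes cover K, M_lan = 2 rows per wave,
    // M_wav = 4 waves, and M_rep = 32 / (2 * 4) = 4 repeats per thread along M.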
    // Caution: this will require the weight in global memory to be pre-shuffled into the
    // mfma layout, maximizing L1/L2 channel utilization while skipping LDS
template <index_t NPerBlock,
index_t KPerBlock,
index_t WavesPerBlock_N,
index_t WavesPerBlock_K,
typename WarpGemm,
index_t Alignment,
FusedMoeWeightPermuteEnum PermuteStyle =
FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_MatrixCore_Swizzled()
{
static_assert(Alignment % WarpGemm::WarpGemmAttribute::Impl::kABKPerLane == 0);
if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
{
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            static_assert(KPerBlock % (Kv * Kw) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
            // clang-format off
            return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>, // 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
// Nr_p, Kr_p Kw Nw
tuple<sequence<1, 2>, sequence<3, 3>>,
tuple<sequence<1, 1>, sequence<0, 1>>,
// Nr_y Kr_y Kv
sequence<1, 2, 3>,
sequence<0, 0, 2>>{});
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kM_a;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t Alignment = GetAlignment_A<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<kMPerBlock,
kKPerBlock,
NumWarps,
Alignment>();
}
    template <typename Problem, index_t NSplits = 2>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G(number<NSplits> = {})
    {
        constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            // each of the NSplits sub-windows only covers kBlockN_0 / NSplits rows, so
            // build the distribution for the sub-tile (assumes NSplits divides kBlockN_0)
            constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kBlockN_0 / NSplits;
            constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
            constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
            constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
            constexpr index_t Alignment = GetAlignment_G<Problem>();
            return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                                  kKPerBlock,
                                                                  WavesPerBlock_N,
                                                                  WavesPerBlock_K,
                                                                  WarpGemm,
                                                                  Alignment,
                                                                  PermuteStyle>();
        }
    }
    template <typename Problem, index_t NSplits = 2>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_U(number<NSplits> = {})
    {
        constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            // same sub-tile logic as MakeGlobalTileDistribution_G
            constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kBlockN_0 / NSplits;
            constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
            constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_0;
            constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_0;
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
            constexpr index_t Alignment = GetAlignment_U<Problem>();
            return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                                  kKPerBlock,
                                                                  WavesPerBlock_N,
                                                                  WavesPerBlock_K,
                                                                  WarpGemm,
                                                                  Alignment,
                                                                  PermuteStyle>();
        }
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
    {
        constexpr auto PermuteStyle = Problem::Traits::PermuteStyle;
        if constexpr(PermuteStyle == FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
        {
            constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kBlockN_1;
            constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_1;
            constexpr index_t WavesPerBlock_N = Problem::FusedMoeTileShape::kBlockWarpsN_1;
            constexpr index_t WavesPerBlock_K = Problem::FusedMoeTileShape::kBlockWarpsK_1;
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
            constexpr index_t Alignment = GetAlignment_D<Problem>();
            return MakeGlobalTileDistribution_MatrixCore_Swizzled<kNPerBlock,
                                                                  kKPerBlock,
                                                                  WavesPerBlock_N,
                                                                  WavesPerBlock_K,
                                                                  WarpGemm,
                                                                  Alignment,
                                                                  PermuteStyle>();
        }
    }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
static_assert(kKPerBlock % kVector == 0);
constexpr index_t LanesPerK = kKPerBlock / kVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = kMPerBlock / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = kMPerBlock / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<kKPerBlock>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
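    // Worked example (hypothetical): an fp16 A tile with kKPerBlock = 128, KVector = 2
    // (one dword per async copy) gives LanesPerK = 64 = warpSize, so wavesPerK = 1,
    // wavesPerM = NumWarps = 4 and NumIssues = kMPerBlock / 4; the "+ kPad" term in the
    // strides staggers each wave's rows to keep LDS banks conflict-free.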
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
        // 2). returned descriptor is in MxK 2d layout
constexpr index_t kMPerBlock = Problem::FusedMoeTileShape::kBlockM_0;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t kPad = KPack; // pad between warps
static_assert(kKPerBlock % kVector == 0);
constexpr index_t LanesPerK = kKPerBlock / kVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK > NumWarps) // keep consistent with MakeLdsStoreDesc_A
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = kMPerBlock / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + kPad)>{}, // m1
number<warpSize * KVector + kPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = kMPerBlock / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{}, // m0
number<kKPerBlock>{}, // m1
number<warpSize * KVector + kPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
return WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm0WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
return WarpGemmMfmaDispatcher<typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<0>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<1>{}),
Problem::FusedMoeTileShape::Gemm1WarpTile::at(number<2>{}),
true /*TransposeC*/>{};
}
template <typename Problem, index_t NSplits = 2>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0(number<NSplits> = {}) const
{
using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
constexpr index_t BlockWarpsM = TileShape::kBlockWarpsM_0;
constexpr index_t BlockWarpsN = TileShape::kBlockWarpsN_0;
constexpr index_t WarpRepeatM = TileShape::kWarpRepeatM_0;
constexpr index_t WarpRepeatN = TileShape::kWarpRepeatN_0;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
static_assert(WarpRepeatN % NSplits == 0);
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<WarpRepeatM, BlockWarpsM>, sequence<WarpRepeatN / NSplits, BlockWarpsN>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor =
            make_static_distributed_tensor<typename Problem::AccDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1() const
{
using TileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
constexpr index_t BlockWarpsM = TileShape::kBlockWarpsM_1;
constexpr index_t BlockWarpsN = TileShape::kBlockWarpsN_1;
constexpr index_t WarpRepeatM = TileShape::kWarpRepeatM_1;
constexpr index_t WarpRepeatN = TileShape::kWarpRepeatN_1;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<WarpRepeatM, BlockWarpsM>, sequence<WarpRepeatN, BlockWarpsN>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor =
            make_static_distributed_tensor<typename Problem::AccDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_0()
{
if constexpr(Problem::Traits::PermuteStyle ==
FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
{
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>; // assume warpgemm0/1 are the same
            constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_0;
            constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_0;
            constexpr index_t Kv = GetAlignment_G<Problem>();
            constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            static_assert(KPerBlock % (Kv * Kw) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetMatrixCoreSwizzledBlockTIle_1()
{
if constexpr(Problem::Traits::PermuteStyle ==
FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv)
{
            using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>; // assume warpgemm0/1 are the same
            constexpr index_t NPerBlock = Problem::FusedMoeTileShape::kBlockN_1;
            constexpr index_t KPerBlock = Problem::FusedMoeTileShape::kBlockK_1;
            constexpr index_t Kv = GetAlignment_D<Problem>(); // D alignment for the second gemm
            constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            static_assert(KPerBlock % (Kv * Kw) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
return sequence<Nr, Kr, Kw * Nw * Kv>{}; // 3D
}
}
};
} // namespace ck_tile
...@@ -33,17 +33,7 @@ struct FusedMoePipelineProblem
    static constexpr index_t kBlockSize = FusedMoeTileShape::NumWarps * get_warp_size();
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
    // using GateActivation = remove_cvref_t<typename Traits::GateActivation_>;
};
} // namespace ck_tile
...@@ -37,12 +37,11 @@ M_a | A | | | | | | | | |
SILU x-----x +----------+
K_y = N_g = N_u dim
*/
template <typename BlockTile_, // sequence<M_a, N_g, N_sub0, K_a, N_d>
          typename Gemm0BlockWarps_,
          typename Gemm0WarpTile_,
          typename Gemm1BlockWarps_,
          typename Gemm1WarpTile_>
struct FusedMoeTileShape
{
    using BlockTile = remove_cvref_t<BlockTile_>;
...@@ -58,25 +57,60 @@ struct FusedMoeTileShape
    static constexpr index_t kM_a = BlockTile::at(number<0>{});
    static constexpr index_t kN_g = BlockTile::at(number<1>{});
    static constexpr index_t kN_u = BlockTile::at(number<1>{});
    // e.g. N_g = 256, n_sub_gu = 128: we split blockN of G/U into 2 parts to loop over;
    // this helps when the B matrix goes direct-to-register and would use too many vgprs
    static constexpr index_t kN_sub0 = BlockTile::at(number<2>{});
    static constexpr index_t kK_a    = BlockTile::at(number<3>{});
    static constexpr index_t kN_d    = BlockTile::at(number<4>{});
    // static_assert(kN_g == kN_u);
    static constexpr index_t kK_y = kN_g;

    static constexpr index_t kBlockNSub_0 = kN_sub0; // allow partial
    static constexpr index_t kBlockM_0 = kM_a;
    static constexpr index_t kBlockN_0 = kN_g; // note N will x2 in real pipeline for gemm-0
    static constexpr index_t kBlockK_0 = kK_a;
    static constexpr index_t kWarpM_0 = Gemm0WarpTile::at(number<0>{});
    static constexpr index_t kWarpN_0 = Gemm0WarpTile::at(number<1>{});
    static constexpr index_t kWarpK_0 = Gemm0WarpTile::at(number<2>{});
    static constexpr index_t kBlockWarpsM_0 = Gemm0BlockWarps::at(number<0>{});
    static constexpr index_t kBlockWarpsN_0 = Gemm0BlockWarps::at(number<1>{});
    static constexpr index_t kBlockWarpsK_0 = Gemm0BlockWarps::at(number<2>{});
    static constexpr index_t kSubBlockM_0 = kWarpM_0 * kBlockWarpsM_0;
    static constexpr index_t kSubBlockN_0 = kWarpN_0 * kBlockWarpsN_0;
    static constexpr index_t kSubBlockK_0 = kWarpK_0 * kBlockWarpsK_0;
    static_assert(kBlockM_0 % kSubBlockM_0 == 0);
    static_assert(kBlockN_0 % kSubBlockN_0 == 0);
    static_assert(kBlockK_0 % kSubBlockK_0 == 0);
    static constexpr index_t kWarpRepeatM_0 = kBlockM_0 / kSubBlockM_0;
    static constexpr index_t kWarpRepeatN_0 = kBlockN_0 / kSubBlockN_0;
    static constexpr index_t kWarpRepeatK_0 = kBlockK_0 / kSubBlockK_0;

    static constexpr index_t kBlockKSub_1 = kBlockNSub_0;
    static constexpr index_t kBlockM_1 = kM_a;
    static constexpr index_t kBlockN_1 = kN_d;
    static constexpr index_t kBlockK_1 = kN_g;
    static constexpr index_t kWarpM_1 = Gemm1WarpTile::at(number<0>{});
    static constexpr index_t kWarpN_1 = Gemm1WarpTile::at(number<1>{});
    static constexpr index_t kWarpK_1 = Gemm1WarpTile::at(number<2>{});
    static constexpr index_t kBlockWarpsM_1 = Gemm1BlockWarps::at(number<0>{});
    static constexpr index_t kBlockWarpsN_1 = Gemm1BlockWarps::at(number<1>{});
    static constexpr index_t kBlockWarpsK_1 = Gemm1BlockWarps::at(number<2>{});
    static constexpr index_t kSubBlockM_1 = kWarpM_1 * kBlockWarpsM_1;
    static constexpr index_t kSubBlockN_1 = kWarpN_1 * kBlockWarpsN_1;
    static constexpr index_t kSubBlockK_1 = kWarpK_1 * kBlockWarpsK_1;
    static_assert(kBlockM_1 % kSubBlockM_1 == 0);
    static_assert(kBlockN_1 % kSubBlockN_1 == 0);
    static_assert(kBlockK_1 % kSubBlockK_1 == 0);
    static constexpr index_t kWarpRepeatM_1 = kBlockM_1 / kSubBlockM_1;
    static constexpr index_t kWarpRepeatN_1 = kBlockN_1 / kSubBlockN_1;
    static constexpr index_t kWarpRepeatK_1 = kBlockK_1 / kSubBlockK_1;

    // d, rowmajor : hidden*dim, colmajor : dim*hidden (vLLM uses this layout)
    // static constexpr bool IsDLayoutRowMajor = IsDLayoutRowMajor_;
    // using DLayout = std::conditional_t<IsDLayoutRowMajor,
    //                                    ck_tile::tensor_layout::gemm::RowMajor,
    //                                    ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
namespace ck_tile {
template <bool DownPreShuffled_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */,
index_t OAtomic_ = 0,
FusedMoeWeightPermuteEnum WeightPermute_ =
FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
struct FusedMoeTraits
{
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr FusedMoeWeightPermuteEnum WeightPermute = WeightPermute_;
static constexpr index_t OAtomic = OAtomic_; // 0-pack fp16/bf16 atomic, 1-fp32 atomic
};
} // namespace ck_tile
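// A minimal usage sketch of the traits above (illustrative, not part of this
// diff): pre-shuffled down weights, default occupancy, fp32 atomics for the
// output, and the waveflatten weight permute.
//
// using Traits = ck_tile::FusedMoeTraits<
//     /*DownPreShuffled_=*/true,
//     /*kBlockPerCu_=*/-1,
//     /*OAtomic_=*/1,
//     ck_tile::FusedMoeWeightPermuteEnum::permute_b_nr_kr_waveflatten>;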
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
enum class FusedMoeWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
no_permute = 999,
};
} // namespace ck_tile
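// Reading the index list of permute_b_nr_kr_kw_nw_kv (0,1,3,4,2,5) against an
// assumed source ordering [b, nr, nw, kr, kw, kv] yields the destination
// ordering [b, nr, kr, kw, nw, kv]: the per-wave kw/nw/kv tiles land
// innermost, hence the "waveflatten" alias. (This reading is inferred from
// the enum names; the diff itself does not spell it out.)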
...@@ -616,11 +616,51 @@ struct buffer_store_if<1> ...@@ -616,11 +616,51 @@ struct buffer_store_if<1>
} }
}; };
CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence_raw(number<cnt>)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
#if 0
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
// const index_t origin_cnt = 0x0f70;
__builtin_amdgcn_s_waitcnt(0x0f70 | cnt);
}
#endif
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence(number<cnt>)
{
/*
vmcnt is split in the s_waitcnt simm16 encoding:
cnt[3:0] -> simm16[3:0], cnt[5:4] -> simm16[15:14]
*/
static_assert(cnt < 64);
constexpr index_t low = cnt & 0xf;
constexpr index_t hi = (cnt & 0x30) << 10; // place cnt[5:4] into simm16[15:14]
constexpr index_t c = 0x0f70 | low | hi;
__builtin_amdgcn_s_waitcnt(c);
}
CK_TILE_DEVICE void wave_barrier() { __builtin_amdgcn_s_barrier(); }
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{ {
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
template <index_t cnt>
CK_TILE_DEVICE auto async_load_fence(number<cnt>)
{
buffer_load_fence(number<cnt>{});
}
// clang-format off // clang-format off
namespace impl{ namespace impl{
...@@ -706,13 +746,13 @@ CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by) ...@@ -706,13 +746,13 @@ CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
} }
// clang-format on // clang-format on
template <typename... T> template <typename... T>
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... o) CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0, T&... o)
{ {
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
impl::insert_dummy_dep(o...); impl::insert_dummy_dep(o...);
} }
CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) CK_TILE_DEVICE void buffer_store_fence_raw(index_t cnt = 0)
{ {
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
...@@ -976,6 +1016,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, ...@@ -976,6 +1016,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false> template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -998,10 +1048,12 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, ...@@ -998,10 +1048,12 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
: "memory"); : "memory");
} }
#if 0
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
{ {
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
#endif
// memory coherency bit for buffer store/load instruction // memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target // check ISA manual for each GFX target
...@@ -1365,6 +1417,45 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, ...@@ -1365,6 +1417,45 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
// when the element is out of bounds, point voffset past the buffer range
// (rsrc word 2 holds num_records) so the hardware drops the load
index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem, // reinterpret_cast<CK_TILE_LDS_ADDR
// uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem, // reinterpret_cast<CK_TILE_LDS_ADDR
// uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
}
template <index_t N, template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default> amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data, CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
...@@ -2094,6 +2185,28 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, ...@@ -2094,6 +2185,28 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{}); smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
} }
// This version supports a buffer resource as an input argument
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
0,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// buffer_store requires: // buffer_store requires:
// 1) p_dst_wave must point to global memory // 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer. // 2) p_dst_wave must be a wavewise pointer.
...@@ -2221,16 +2334,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_ ...@@ -2221,16 +2334,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif #endif
} }
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread> template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset, const index_t global_offset,
......
...@@ -43,6 +43,20 @@ ...@@ -43,6 +43,20 @@
#define CK_TILE_HOST_DEVICE_EXTERN #define CK_TILE_HOST_DEVICE_EXTERN
#endif #endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code #define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif #endif
......
...@@ -369,6 +369,31 @@ struct buffer_view<address_space_enum::global, ...@@ -369,6 +369,31 @@ struct buffer_view<address_space_enum::global,
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{}); dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
} }
// i is the offset in units of T, not X; i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
index_t i,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
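// e.g. (illustrative, assuming T is a 16-bit scalar and X a 2-element vector
// of T): t_per_x = 2, matching the fixed 4-byte granularity that
// buffer_load ... lds supports.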
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem, cached_buf_res_, i, is_valid_element, bool_constant<oob_conditional_check>{});
}
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool pre_nop = false, bool pre_nop = false,
......
...@@ -49,6 +49,26 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, ...@@ -49,6 +49,26 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
// for this API we require the user to create the smem window from a pointer
// annotated with the CK_TILE_LDS_ADDR attribute, which lets the compiler
// properly detect dependencies when multiple smem windows (multi-buffering) are used
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}
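// Hedged call-site sketch (window construction elided; the identifiers are
// illustrative, not from this diff):
//
//   CK_TILE_LDS_ADDR KDataType* k_smem = ...;  // LDS base carries the attribute
//   auto k_lds  = /* make LDS tile window over k_smem */;
//   auto k_dram = /* make DRAM tile window with static distribution */;
//   async_load_tile(k_lds, k_dram);  // issues buffer_load ... lds per access
//   async_load_fence_raw();          // s_waitcnt vmcnt(0) on outstanding loads
//   wave_barrier();                  // synchronize before reading back from LDS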
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
...@@ -69,11 +89,6 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile, ...@@ -69,11 +89,6 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{}); lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename WindowLengths> template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&) CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{ {
......
...@@ -187,4 +187,30 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten ...@@ -187,4 +187,30 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
}); });
} }
namespace detail {
// check whether two static_distributed_tensor types share the same data type
// and thread-buffer size, differing only in their distributions
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
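// Hedged compile-time usage sketch (the TensorX/TensorY aliases are assumed):
//
// static_assert(detail::is_similiar_distributed_tensor_v<TensorX, TensorY>,
//               "tensors must match in data type and thread-buffer size");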
} // namespace ck_tile } // namespace ck_tile
...@@ -104,6 +104,23 @@ struct tensor_view ...@@ -104,6 +104,23 @@ struct tensor_view
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
{
return buf_.template async_get<X>(
smem,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
template <typename X, template <typename X,
bool pre_nop = false, bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
......
...@@ -495,6 +495,74 @@ struct tile_window_with_static_distribution ...@@ -495,6 +495,74 @@ struct tile_window_with_static_distribution
}); });
} }
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: an LDS offset is not good for the intrinsic-based implementation (the
// compiler can't figure out the dependency), hence avoid an offset-based
// solution. size_per_buf should be zero (how to check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
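// Hedged illustration (shape and packed strides assumed): for an LDS view of
// lengths (issues, warps, lanes) = (4, 4, 64), size_per_wave evaluates to 64
// elements and size_per_issue to 256 elements.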
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structured bindings (to be captured by the lambda below) once C++20 is available
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
});
}
template <bool oob_conditional_check = true> template <bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor, CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {}) const
......
...@@ -39,7 +39,7 @@ struct Default2DEpilogue ...@@ -39,7 +39,7 @@ struct Default2DEpilogue
if constexpr(kPadM || kPadN) if constexpr(kPadM || kPadN)
{ {
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence(); buffer_store_fence_raw();
} }
else else
{ {
......
...@@ -274,8 +274,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -274,8 +274,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
store_tile(lse_acc_dram_window_tmp, store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc)); tile_elementwise_in(lse_acc_element_func, lse_acc));
} }
buffer_load_fence(0); // rocm-6.1, if the whole tile is masked out, we need fence(0) buffer_load_fence_raw(0); // rocm-6.1, if the whole tile is masked out, we need fence(0)
// otherwise there will be compute errors (maybe a compiler bug?) // otherwise there will be compute errors (maybe a compiler bug?)
// Note: here o_acc is already cleared, return it // Note: here o_acc is already cleared, return it
return o_acc; return o_acc;
...@@ -315,7 +315,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -315,7 +315,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x: using the q element func causes scratch usage on hdim=64/32 (void)q_element_func; // ??? rocm-6.x: using the q element func causes scratch usage on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q); // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...@@ -338,7 +338,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -338,7 +338,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access()); async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc, gemm_0(s_acc,
...@@ -360,7 +360,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync ...@@ -360,7 +360,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
if constexpr(k0_loops <= 2) if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
async_load_fence(); async_load_fence_raw();
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
......