Commit 199f7f71 authored by carlushuang

modify moe

parent 33ceea62
@@ -3,16 +3,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
-enum class FusedMoePermuteStyle
+enum class FusedMoeWeightPermuteEnum
{
    // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
@@ -20,14 +12,4 @@ enum class FusedMoePermuteStyle
    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
    no_permute = 999,
};
}
template <bool DownPreShuffled_ = false,
FusedMoePermuteStyle PermuteStyle_ = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits
{
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr FusedMoePermuteStyle PermuteStyle = PermuteStyle_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile
@@ -14,7 +14,7 @@ namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
-struct BlockFmhaPipelineQRKSVSAsync
+struct FusedMoePipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;
@@ -27,43 +27,49 @@ struct BlockFmhaPipelineQRKSVSAsync
    using AccDataType       = remove_cvref_t<typename Problem::AccDataType>;
    using ScaleDataType     = remove_cvref_t<typename Problem::ScaleDataType>;
    using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
using VLayout = remove_cvref_t<typename FusedMoeTileShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
    static constexpr index_t kBlockSize = Problem::kBlockSize;
-    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
-    // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
-    // only need special care about seq_k padding (oob need set -INF of p instead of zero)
-    static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
-                  Problem::kPadHeadDimV == true);
-    static constexpr bool kPadSeqLenQ  = true;
-    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
-    static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
-    static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
-    static constexpr auto BiasEnum     = Problem::BiasEnum;
-    static constexpr bool kStoreLSE    = Problem::kStoreLSE;
-    static constexpr bool kHasDropout  = Problem::kHasDropout;
-    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
-    // ... together with tensor distribution. tensor dist should able to overwrite this
-    static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
-    static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
-    static constexpr index_t kAlignmentV = []() {
-        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
-            return Policy::template GetAlignmentV<Problem>();
-        else
-            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
-    }();
-    static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
-    static constexpr index_t kAlignmentBias =
-        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
-#if CK_TILE_FMHA_FWD_FAST_EXP2
-    static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
-#endif
+    static constexpr index_t kBlockM_0      = FusedMoeTileShape::kBlockM_0;
+    static constexpr index_t kBlockN_0      = FusedMoeTileShape::kBlockN_0;
+    static constexpr index_t kBlockK_0      = FusedMoeTileShape::kBlockK_0;
+    static constexpr index_t kWarpM_0       = FusedMoeTileShape::kWarpM_0;
+    static constexpr index_t kWarpN_0       = FusedMoeTileShape::kWarpN_0;
+    static constexpr index_t kWarpK_0       = FusedMoeTileShape::kWarpK_0;
+    static constexpr index_t kBlockWarpsM_0 = FusedMoeTileShape::kBlockWarpsM_0;
+    static constexpr index_t kBlockWarpsN_0 = FusedMoeTileShape::kBlockWarpsN_0;
+    static constexpr index_t kBlockWarpsK_0 = FusedMoeTileShape::kBlockWarpsK_0;
+    static constexpr index_t kSubBlockM_0   = FusedMoeTileShape::kSubBlockM_0;
+    static constexpr index_t kSubBlockN_0   = FusedMoeTileShape::kSubBlockN_0;
+    static constexpr index_t kSubBlockK_0   = FusedMoeTileShape::kSubBlockK_0;
+    static constexpr index_t kWarpRepeatM_0 = FusedMoeTileShape::kWarpRepeatM_0;
+    static constexpr index_t kWarpRepeatN_0 = FusedMoeTileShape::kWarpRepeatN_0;
+    static constexpr index_t kWarpRepeatK_0 = FusedMoeTileShape::kWarpRepeatK_0;
+    static constexpr index_t kBlockM_1      = FusedMoeTileShape::kBlockM_1;
+    static constexpr index_t kBlockN_1      = FusedMoeTileShape::kBlockN_1;
+    static constexpr index_t kBlockK_1      = FusedMoeTileShape::kBlockK_1;
+    static constexpr index_t kWarpM_1       = FusedMoeTileShape::kWarpM_1;
+    static constexpr index_t kWarpN_1       = FusedMoeTileShape::kWarpN_1;
+    static constexpr index_t kWarpK_1       = FusedMoeTileShape::kWarpK_1;
+    static constexpr index_t kBlockWarpsM_1 = FusedMoeTileShape::kBlockWarpsM_1;
+    static constexpr index_t kBlockWarpsN_1 = FusedMoeTileShape::kBlockWarpsN_1;
+    static constexpr index_t kBlockWarpsK_1 = FusedMoeTileShape::kBlockWarpsK_1;
+    static constexpr index_t kSubBlockM_1   = FusedMoeTileShape::kSubBlockM_1;
+    static constexpr index_t kSubBlockN_1   = FusedMoeTileShape::kSubBlockN_1;
+    static constexpr index_t kSubBlockK_1   = FusedMoeTileShape::kSubBlockK_1;
+    static constexpr index_t kWarpRepeatM_1 = FusedMoeTileShape::kWarpRepeatM_1;
+    static constexpr index_t kWarpRepeatN_1 = FusedMoeTileShape::kWarpRepeatN_1;
+    static constexpr index_t kWarpRepeatK_1 = FusedMoeTileShape::kWarpRepeatK_1;
using MBlockType = decltype(GetMatrixCoreSwizzledBlockTIle_0<Problem>());
    static constexpr index_t kBlockNr_0        = MBlockType{}.at(number<0>{});
    static constexpr index_t kBlockKr_0        = MBlockType{}.at(number<1>{});
    static constexpr index_t kBlockWaveFlatten = MBlockType{}.at(number<2>{});
    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
@@ -71,37 +77,7 @@ struct BlockFmhaPipelineQRKSVSAsync
        else
        {
            // minimize occupancy
-            if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
+            return 2;
{
return 1;
}
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
{
return 1;
}
        }
    }();
@@ -179,23 +155,261 @@ struct BlockFmhaPipelineQRKSVSAsync
            o_gtile_window_tmp.get_window_lengths(),
            o_gtile_window_tmp.get_window_origin(),
            Policy::template MakeOGlobalTileDistribution<Problem>());
using g_thread_type = decltype(load_tile(g_gtile_window));
using u_thread_type = decltype(load_tile(u_gtile_window));
using d_thread_type = decltype(load_tile(d_gtile_window));
const index_t loops_0 = (dim_size + kBlockK_0 - 1) / kBlockK_0;
const index_t loops_1 = (dim_size + kBlockN_1 - 1) / kBlockN_1;
// auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
// issues_warps_lanes
auto a_sst_0 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// issues_warps_lanes
auto a_sst_1 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// m*k
auto a_sld_0 = make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0});
// m*k
auto a_sld_1 = make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0});
        g_thread_type g_tile[2];
        u_thread_type u_tile[2];
        using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
        using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
        // TODO: N first, M next
const index_t i_mwarp_0 = get_warp_id() / kBlockWarpsN_0;
// create and pre-cache a warp-window
auto make_a_warp_windows = [&](auto a_sld_) {
// construct A-warp-window
auto warp_window = make_tile_window(
a_sld_.get_bottom_tensor_view(),
make_tuple(number<WarpGemm0::kM>{}, number<WarpGemm0::kK>{}),
a_sld_.get_window_origin() + multi_index<2>{i_mwarp_0 * WarpGemm0::kM, 0},
make_static_tile_distribution(typename WarpGemm0::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(warp_window), kWarpRepeatK_0>,
kWarpRepeatM_0>
ws;
// pre-cache the warp windows
static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m_iter) {
static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k_iter) {
ws(i_m_iter)(i_k_iter) = warp_window;
move_tile_window(ws(i_m_iter)(i_k_iter),
{i_m_iter * NPerBlockPerIter, i_k_iter * KPerBlockPerIter});
});
});
return ws;
};
auto a_warp_windows_0 = make_a_warp_windows(a_sld_0);
auto a_warp_windows_1 = make_a_warp_windows(a_sld_1);
constexpr auto true_v = bool_constant<true>{};
constexpr auto false_v = bool_constant<false>{};
auto do_load_a0 = [&](auto& a_store_, auto move_) {
async_load_tile(a_store_, a_gtile_window);
if constexpr(move_)
move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}});
};
auto do_load_b0 = [&](auto& g_tile_, auto& u_tile_, auto move_) {
g_tile_ = load_tile(g_gtile_window);
u_tile_ = load_tile(u_gtile_window);
if constexpr(move_)
{
move_tile_window(g_gtile_window, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
move_tile_window(u_gtile_window, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
}
};
auto do_load_b1 = [&](auto& d_tile_, auto move_) {
d_tile_ = load_tile(d_gtile_window);
if constexpr(move_)
{
move_tile_window(d_gtile_window, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
}
};
// using AWarpTensor = typename decltype(warp_gemm_0)::AWarpTensor{};
// using CWarpTensor =
auto acc_g = MakeCBlockTile_Gemm0<Problem>();
auto acc_u = MakeCBlockTile_Gemm0<Problem>();
// async_load_tile(a_sst_0, a_gtile_window); move_tile_window(a_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); g_tile[0] = load_tile(g_gtile_window);
// move_tile_window(g_gtile_window, {number<0>{}, number<kBlockK_0>{}}); u_tile[0] =
// load_tile(u_gtile_window); move_tile_window(u_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); async_load_tile(a_sst_1, a_gtile_window);
// move_tile_window(a_gtile_window, {number<0>{}, number<kBlockK_0>{}}); g_tile[1] =
// load_tile(g_gtile_window); move_tile_window(g_gtile_window, {number<0>{},
// number<kBlockK_0>{}}); u_tile[1] = load_tile(u_gtile_window);
// move_tile_window(u_gtile_window, {number<0>{}, number<kBlockK_0>{}});
auto do_gemm_0 =
[&](auto& acc_g_, auto& acc_u_, auto& a_windows_, auto& g_tile_, auto& u_tile_) {
// as_br (asmem, breg)
static_for<0, kWarpRepeatK_0, 1>{}([&](auto i_k) {
static_for<0, kWarpRepeatM_0, 1>{}([&](auto i_m) {
const auto w_a = load_tile(a_windows_(i_m)(i_k));
static_for<0, kWarpRepeatN_0, 1>{}([&](auto i_n) {
constexpr auto beg_acc =
sequence<i_m * kSubBlockM_0, i_n * kSubBlockN_0>{};
constexpr auto end_acc =
sequence<(i_m + 1) * kSubBlockM_0, (i_n + 1) * kSubBlockN_0>{};
// 3d indexing for permuted g/u/d
constexpr auto beg_b =
sequence<i_m * kBlockWarpsM_0, i_n * kSubBlockN_0, 0>{};
constexpr auto end_b =
sequence<(i_m + 1) * kBlockWarpsM_0, (i_n + 1) * kSubBlockN_0, 0>{};
auto w_acc_g = get_slice_tile(acc_g_, beg_acc, end_acc);
auto w_acc_u = get_slice_tile(acc_u_, beg_acc, end_acc);
auto w_g = get_slice_tile(g_tile_, beg_b, end_b);
auto w_u = get_slice_tile(u_tile_, beg_b, end_b);
warp_gemm_0(w_acc_g, w_a, w_g);
warp_gemm_0(w_acc_u, w_a, w_u);
set_slice_tile(acc_g_, w_acc_g, beg_acc, end_acc);
set_slice_tile(acc_u_, w_acc_u, beg_acc, end_acc);
});
});
});
};
auto do_gemm_1 = [&](auto& acc_d_, auto& a_tile_, auto& d_tile_) {
// ar_br (areg, breg)
static_for<0, kWarpRepeatK_1, 1>{}([&](auto i_k) {
static_for<0, kWarpRepeatM_1, 1>{}([&](auto i_m) {
constexpr auto beg_a = sequence<i_m * kSubBlockM_1, i_k * kSubBlockK_1>{};
constexpr auto end_a =
sequence<(i_m + 1) * kSubBlockM_1, (i_k + 1) * kSubBlockK_1>{};
const auto w_a = get_slice_tile(a_tile_, beg_a, end_a);
static_for<0, kWarpRepeatN_1, 1>{}([&](auto i_n) {
constexpr auto beg_acc = sequence<i_m * kSubBlockM_0, i_n * kSubBlockN_0>{};
constexpr auto end_acc =
sequence<(i_m + 1) * kSubBlockM_0, (i_n + 1) * kSubBlockN_0>{};
// 3d indexing for permuted g/u/d
constexpr auto beg_b =
sequence<i_m * kBlockWarpsM_0, i_n * kSubBlockN_0, 0>{};
constexpr auto end_b =
sequence<(i_m + 1) * kBlockWarpsM_0, (i_n + 1) * kSubBlockN_0, 0>{};
auto w_acc_d = get_slice_tile(acc_d_, beg_acc, end_acc);
auto w_d = get_slice_tile(d_tile_, beg_b, end_b);
warp_gemm_1(w_acc_d, w_a, w_d);
set_slice_tile(acc_d_, w_acc_d, beg_acc, end_acc);
});
});
});
};
// start of pipeline
do_load_a0(a_sst_0, true_v);
do_load_b0(g_tile[0], u_tile[0], true_v);
do_load_a0(a_sst_1, true_v);
do_load_b0(g_tile[1], u_tile[1], true_v);
clear_tile(acc_g);
clear_tile(acc_u);
-        constexpr auto k_per_block_0 = Problem::FusedMoeTileShape::kK_a;
-        const index_t loops_0        = (dim_size + k_per_block_0 - 1) / k_per_block_0;
+        index_t i_0 = 0;
+        while(i_0 < (loops_0 - 2))
{
// first buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_0, g_tile[0], u_tile[0]);
do_load_a0(a_sst_0, true_v);
do_load_b0(g_tile[0], u_tile[0], true_v);
i_0++;
// second buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_1, g_tile[1], u_tile[1]);
do_load_a0(a_sst_1, true_v);
do_load_b0(g_tile[1], u_tile[1], true_v);
i_0++;
}
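        // The loop above is a two-stage ping-pong: each half consumes one LDS buffer
        // (a_sld_0/a_sld_1) together with its prefetched G/U register tiles, then
        // immediately re-issues the async A copy and the next G/U loads for that buffer,
        // so global traffic overlaps with the MFMAs of the other buffer. The remaining
        // two iterations (trip count is loops_0 - 2) are drained below, with the first
        // D (down-projection) tiles prefetched in between.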
// first buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_0, g_tile[0], u_tile[0]);
// prefetch
d_thread_type d_tile[2];
do_load_b1(d_tile[0], true_v);
do_load_b1(d_tile[1], true_v);
// second buffer
do_gemm_0(acc_g, acc_u, a_warp_windows_1, g_tile[1], u_tile[1]);
        // reduce acc_g/u
constexpr auto acc_spans_0 = decltype(acc_g)::get_distributed_spans();
sweep_tile_span(acc_spans_0[number<0>{}], [&](auto idx0) {
sweep_tile_span(acc_spans_0[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
element_wise::Silu{}(acc_g(i_j_idx), acc_g(i_j_idx));
acc_g(i_j_idx) *= acc_u(i_j_idx);
});
});
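        // At this point acc_g holds the gated intermediate of gemm-0:
        //   acc_g <- Silu(A * G) * (A * U)   (elementwise over the M0 x N0 block)
        // It is cast to YDataType below and used as the A operand of gemm-1
        // (y * D accumulated into acc_d).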
-        constexpr auto n_per_block_1 = Problem::FusedMoeTileShape::kN_d;
-        const index_t loops_1        = (dim_size + n_per_block_1 - 1) / n_per_block_1;
+        const auto y = [&]() {
+            if constexpr(std::is_same_v<YDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<YDataType>(acc_g);
else
return cast_tile<YDataType>(acc_g);
}();
-        auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
+        auto acc_d = MakeCBlockTile_Gemm1<Problem>();
clear_tile(acc_d);
        // TODO: reshuffle? 32x32x8 mfma can avoid LDS reshuffle
        index_t i_1 = 0;
while(i_1 < (loops_1 - 2))
{
// first buffer
do_gemm_1(acc_d, y, d_tile[0]);
do_load_b1(d_tile[0], true_v);
i_1++;
// second buffer
do_gemm_1(acc_d, y, d_tile[1]);
do_load_b1(d_tile[1], true_v);
i_1++;
}
-        auto smem_0_window = make_tile_window(
-            make_tensor_view<address_space_enum::lds>(
-                smem_0, Policy::template MakeLdsStoreBlockDescriptor_A<Problem>()),
-            Policy::template MakeLdsStoreBlockDescriptor_A<Problem>().get_lengths(),
-            {0, 0});
-        async_load_tile(k_lds_store(LdsSeq.at(number<0>{})));
-        for(index_t i_0 = 0; i_0 < loops_0; i_0++) {}
+        // first buffer
+        do_gemm_1(acc_d, y, d_tile[0]);
+        i_1++;
+        // second buffer
+        do_gemm_1(acc_d, y, d_tile[1]);
+        i_1++;
    }
    template <typename QDramBlockWindowTmp,
...
@@ -45,4 +45,31 @@ struct FusedMoeTilePartitioner_PersistentSplitD
    }
};
template <typename FusedMoeTileShape_>
struct FusedMoeTilePartitioner_Linear
{
using Shape = ck_tile::remove_cvref_t<FusedMoeTileShape_>;
static constexpr const char* name = "2d"; // expert x hidden
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
                                   ck_tile::index_t /*hidden_size*/)
{
index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y;
return ck_tile::make_tuple(i_m, i_n);
}
// persistent
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
{
// TODO: this may need tuning
        // index_t grids = num_cu * blocks_per_cu; // leftover from the persistent variant; num_cu/blocks_per_cu are not available here
index_t ms = ck_tile::integer_divide_ceil(max_tokens, Shape::kBlockM_0);
index_t ns = ck_tile::integer_divide_ceil(hidden_size, Shape::kBlockN_0);
return dim3(ns, ms, 1);
}
};
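For reference, the 2-D mapping above is just a ceiling-divide of the token and hidden dimensions into block tiles. A minimal standalone sketch (plain C++; the tile sizes are illustrative stand-ins for Shape::kBlockM_0 / Shape::kBlockN_0, not values from the diff):

    #include <cstdint>
    // hypothetical block tile sizes
    constexpr int64_t kBlockM_0 = 32;
    constexpr int64_t kBlockN_0 = 128;
    constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }
    // grid.x walks the hidden dimension, grid.y walks the (max) token dimension,
    // mirroring GridSize()/operator() of FusedMoeTilePartitioner_Linear
    constexpr int64_t grid_x(int64_t hidden_size) { return ceil_div(hidden_size, kBlockN_0); }
    constexpr int64_t grid_y(int64_t max_tokens) { return ceil_div(max_tokens, kBlockM_0); }
    static_assert(grid_x(4096) == 32 && grid_y(100) == 4, "example: 4096/128 and ceil(100/32) tiles");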
} // namespace ck_tile
@@ -33,17 +33,7 @@ struct FusedMoePipelineProblem
    static constexpr index_t kBlockSize = FusedMoeTileShape::NumWarps * get_warp_size();
// attributes from traits
// static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
// static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
// static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
// static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// static constexpr auto BiasEnum = Traits::BiasEnum;
// static constexpr bool kStoreLSE = Traits::kStoreLSE;
// static constexpr bool kHasDropout = Traits::kHasDropout;
// static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
-    using GateActivation = remove_cvref_t<typename Traits::GateActivation_>;
+    // using GateActivation = remove_cvref_t<typename Traits::GateActivation_>;
};
} // namespace ck_tile
@@ -37,12 +37,11 @@ M_a | A | | | | | | | | |
  SILU   x-----x        +----------+
       K_y = N_g = N_u       dim
*/
-template <typename BlockTile_, // sequence<M_a, N_g, N_u, K_a, N_d
+template <typename BlockTile_, // sequence<M_a, N_g, N_sub0, K_a, N_d
          typename Gemm0BlockWarps_,
          typename Gemm0WarpTile_,
          typename Gemm1BlockWarps_,
-          typename Gemm1WarpTile_,
-          bool IsDLayoutRowMajor_>
+          typename Gemm1WarpTile_>
struct FusedMoeTileShape
{
    using BlockTile = remove_cvref_t<BlockTile_>;
@@ -58,25 +57,60 @@ struct FusedMoeTileShape
    static constexpr index_t kM_a = BlockTile::at(number<0>{});
    static constexpr index_t kN_g = BlockTile::at(number<1>{});
-    static constexpr index_t kN_u = BlockTile::at(number<2>{});
-    static constexpr index_t kK_a = BlockTile::at(number<3>{});
-    static constexpr index_t kN_d = BlockTile::at(number<4>{});
-    static_assert(kN_g == kN_u);
-    static constexpr index_t kK_y = kN_g;
-    static constexpr index_t kM_0 = kM_a;
-    static constexpr index_t kN_0 = kN_g; // note N will x2
+    static constexpr index_t kN_u = BlockTile::at(number<1>{});
+    // e.g. N_g = 256, n_sub_gu=128, then we split blockN of G/U into 2 parts to loopover
+    // this can help B matrix direct-to-register using too much vgpr issue
+    static constexpr index_t kN_sub0 = BlockTile::at(number<2>{});
+    static constexpr index_t kK_a    = BlockTile::at(number<3>{});
+    static constexpr index_t kN_d    = BlockTile::at(number<4>{});
+    // static_assert(kN_g == kN_u);
+    static constexpr index_t kBlockNSub_0 = kN_sub0; // allow partial
+    static constexpr index_t kBlockM_0    = kM_a;
+    static constexpr index_t kBlockN_0    = kN_g; // note N will x2 in real pipeline for gemm-0
static constexpr index_t kK_0 = kK_a; static constexpr index_t kBlockN_0 = kN_g; // note N will x2 in real pipeline for gemm-0
static constexpr index_t kBlockK_0 = kK_a;
static constexpr index_t kWarpM_0 = Gemm0WarpTile::at(number<0>{});
static constexpr index_t kWarpN_0 = Gemm0WarpTile::at(number<1>{});
static constexpr index_t kWarpK_0 = Gemm0WarpTile::at(number<2>{});
    static constexpr index_t kBlockWarpsM_0 = Gemm0BlockWarps::at(number<0>{});
    static constexpr index_t kBlockWarpsN_0 = Gemm0BlockWarps::at(number<1>{});
    static constexpr index_t kBlockWarpsK_0 = Gemm0BlockWarps::at(number<2>{});
static constexpr index_t kSubBlockM_0 = kWarpM_0 * kBlockWarpsM_0;
static constexpr index_t kSubBlockN_0 = kWarpN_0 * kBlockWarpsN_0;
static constexpr index_t kSubBlockK_0 = kWarpK_0 * kBlockWarpsK_0;
static_assert(kBlockM_0 % kSubBlockM_0 == 0);
static_assert(kBlockN_0 % kSubBlockN_0 == 0);
static_assert(kBlockK_0 % kSubBlockK_0 == 0);
static constexpr index_t kWarpRepeatM_0 = kBlockM_0 / kSubBlockM_0;
static constexpr index_t kWarpRepeatN_0 = kBlockN_0 / kSubBlockN_0;
static constexpr index_t kWarpRepeatK_0 = kBlockK_0 / kSubBlockK_0;
-    static constexpr index_t kM_1 = kM_0;
-    static constexpr index_t kN_1 = kN_d;
-    static constexpr index_t kK_1 = kN_g;
+    static constexpr index_t kBlockKSub_1 = kBlockNSub_0;
+    static constexpr index_t kBlockM_1    = kM_a;
+    static constexpr index_t kBlockN_1    = kN_d;
static constexpr index_t kBlockK_1 = kN_g;
static constexpr index_t kWarpM_1 = Gemm1WarpTile::at(number<0>{});
static constexpr index_t kWarpN_1 = Gemm1WarpTile::at(number<1>{});
static constexpr index_t kWarpK_1 = Gemm1WarpTile::at(number<2>{});
    static constexpr index_t kBlockWarpsM_1 = Gemm1BlockWarps::at(number<0>{});
    static constexpr index_t kBlockWarpsN_1 = Gemm1BlockWarps::at(number<1>{});
    static constexpr index_t kBlockWarpsK_1 = Gemm1BlockWarps::at(number<2>{});
static constexpr index_t kSubBlockM_1 = kWarpM_1 * kBlockWarpsM_1;
static constexpr index_t kSubBlockN_1 = kWarpN_1 * kBlockWarpsN_1;
static constexpr index_t kSubBlockK_1 = kWarpK_1 * kBlockWarpsK_1;
static_assert(kBlockM_1 % kSubBlockM_1 == 0);
static_assert(kBlockN_1 % kSubBlockN_1 == 0);
static_assert(kBlockK_1 % kSubBlockK_1 == 0);
static constexpr index_t kWarpRepeatM_1 = kBlockM_1 / kSubBlockM_1;
static constexpr index_t kWarpRepeatN_1 = kBlockN_1 / kSubBlockN_1;
static constexpr index_t kWarpRepeatK_1 = kBlockK_1 / kSubBlockK_1;
    // d, rowmajor : hidden*dim, colmajor : dim*hidden (vLLM use this layout)
-    static constexpr bool IsDLayoutRowMajor = IsDLayoutRowMajor_;
-    using DLayout = std::conditional_t<IsDLayoutRowMajor,
-                                       ck_tile::tensor_layout::gemm::RowMajor,
-                                       ck_tile::tensor_layout::gemm::ColumnMajor>;
+    // static constexpr bool IsDLayoutRowMajor = IsDLayoutRowMajor_;
+    // using DLayout = std::conditional_t<IsDLayoutRowMajor,
+    //                                    ck_tile::tensor_layout::gemm::RowMajor,
+    //                                    ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile
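To make the derived quantities above concrete, a standalone sketch with made-up sizes (illustrative values, not taken from the diff): a 16x16x32 warp tile replicated by a 1x4x1 warp layout gives a 16x64x32 sub-block, and a 32x128x128 block tile is then covered by 2x2x4 warp repeats.

    // warp tile (per-MFMA tile) and block-level warp layout -- illustrative values only
    constexpr int kWarpM = 16, kWarpN = 16, kWarpK = 32;
    constexpr int kBlockWarpsM = 1, kBlockWarpsN = 4, kBlockWarpsK = 1;
    // block tile -- illustrative values only
    constexpr int kBlockM = 32, kBlockN = 128, kBlockK = 128;
    // same derivation as FusedMoeTileShape: sub-block = warp tile * warp layout,
    // warp repeats = block tile / sub-block
    constexpr int kSubBlockM = kWarpM * kBlockWarpsM; // 16
    constexpr int kSubBlockN = kWarpN * kBlockWarpsN; // 64
    constexpr int kSubBlockK = kWarpK * kBlockWarpsK; // 32
    static_assert(kBlockM % kSubBlockM == 0 && kBlockN % kSubBlockN == 0 && kBlockK % kSubBlockK == 0,
                  "block tile must be a multiple of the sub-block");
    constexpr int kWarpRepeatM = kBlockM / kSubBlockM; // 2
    constexpr int kWarpRepeatN = kBlockN / kSubBlockN; // 2
    constexpr int kWarpRepeatK = kBlockK / kSubBlockK; // 4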
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
namespace ck_tile {
template <bool DownPreShuffled_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */,
index_t OAtomic_ = 0,
FusedMoeWeightPermuteEnum WeightPermute_ =
FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv>
struct FusedMoeTraits
{
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr FusedMoeWeightPermuteEnum WeightPermute = WeightPermute_;
static constexpr index_t OAtomic = OAtomic_; // 0-pack fp16/bf16 atomic, 1-fp32 atomic
};
} // namespace ck_tile
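// Example (illustrative, not part of the diff): traits selecting pre-shuffled down weights,
// default occupancy, fp32 atomics for the output, and the wave-flattened weight permute:
//   using Traits = ck_tile::FusedMoeTraits<true, -1, 1,
//                      ck_tile::FusedMoeWeightPermuteEnum::permute_b_nr_kr_waveflatten>;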
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
enum class FusedMoeWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
no_permute = 999,
};
}
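A rough illustration of `permute_b_nr_kr_kw_nw_kv` (the names and split below are one interpretation of the `0,1,3,4,2,5` comment, not taken from the diff): N is split into repeat/wave factors (nr, nw) and K into (kr, kw, kv), and the per-wave factors are stored innermost so a wave reads one contiguous "flattened" chunk.

    #include <cstdint>
    // row-major offset of a 6-d tensor laid out as [b, nr, kr, kw, nw, kv] (hypothetical helper)
    constexpr int64_t offset_b_nr_kr_kw_nw_kv(int64_t b, int64_t nr, int64_t kr, int64_t kw,
                                              int64_t nw, int64_t kv, int64_t Nr, int64_t Kr,
                                              int64_t Kw, int64_t Nw, int64_t Kv)
    {
        return ((((b * Nr + nr) * Kr + kr) * Kw + kw) * Nw + nw) * Kv + kv;
    }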
@@ -616,11 +616,51 @@ struct buffer_store_if<1>
    }
};
CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence_raw(number<cnt>)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
#if 0
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
// const index_t origin_cnt = 0x0f70;
__builtin_amdgcn_s_waitcnt(0x0f70 | cnt);
}
#endif
template <index_t cnt>
CK_TILE_DEVICE void buffer_load_fence(number<cnt>)
{
/*
simm16, simm16[3:0] -> bits[3:0], simm16[15:14] -> bits[5:4]
*/
static_assert(cnt < 64);
constexpr index_t low = cnt & 0xf;
    constexpr index_t hi  = (cnt & 0x30) << 10;
constexpr index_t c = 0x0f70 | low | hi;
__builtin_amdgcn_s_waitcnt(c);
}
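A quick standalone check of the vmcnt packing described above (assuming cnt bits [5:4] land in simm16 bits [15:14], as the comment states); cnt = 20 is just an example value:

    // cnt = 20 -> low nibble 0x4, bits [5:4] = 0x1 shifted to [15:14] -> simm16 = 0x4f74
    static_assert(((20 & 0xf) | ((20 & 0x30) << 10) | 0x0f70) == 0x4f74, "vmcnt packing example");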
CK_TILE_DEVICE void wave_barrier() { __builtin_amdgcn_s_barrier(); }
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template <index_t cnt>
CK_TILE_DEVICE auto async_load_fence(number<cnt>)
{
buffer_load_fence(number<cnt>{});
}
// clang-format off
namespace impl{
@@ -706,13 +746,13 @@ CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
}
// clang-format on
template <typename... T>
-CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... o)
+CK_TILE_DEVICE void buffer_load_fence_raw(index_t cnt = 0, T&... o)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
    impl::insert_dummy_dep(o...);
}
-CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
+CK_TILE_DEVICE void buffer_store_fence_raw(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
@@ -976,6 +1016,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int soffset, // dst_wave_addr_offset
                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
                                              int32x4_t rsrc,
@@ -998,10 +1048,12 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
        : "memory");
}
#if 0
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
#endif
// memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target
@@ -1365,6 +1417,45 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
                                               bool_constant<pre_nop>{});
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem, // reinterpret_cast<CK_TILE_LDS_ADDR
// uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem, // reinterpret_cast<CK_TILE_LDS_ADDR
// uint32_t*>(reinterpret_cast<uintptr_t>(smem)),
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
}
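// Note on the oob_conditional_check path above: the raw.buffer.load.lds intrinsic has no
// per-lane predicate here, so an invalid element is handled by redirecting that lane's
// voffset to the buffer's num_records (dword 2 of the resource descriptor). The access then
// lands out of range and the buffer unit returns 0 instead of reading past the allocation.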
template <index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
@@ -2094,6 +2185,28 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
0,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
@@ -2221,16 +2334,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                                  const index_t global_offset,
...
@@ -43,6 +43,20 @@
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
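// Example usage (illustrative): qualifying an LDS pointer so the compiler selects the
// address-space-specific instructions (ds_read/ds_write, buffer_load ... lds) for it:
//   CK_TILE_LDS_ADDR float* p_tile = reinterpret_cast<CK_TILE_LDS_ADDR float*>(smem);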
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
...
@@ -369,6 +369,31 @@ struct buffer_view<address_space_enum::global,
            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
    }
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
index_t i,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem, cached_buf_res_, i, is_valid_element, bool_constant<oob_conditional_check>{});
}
    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool pre_nop = false,
...
@@ -49,6 +49,26 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// while creating the smem window, which can enable compiler properly detect the
// dependency if using multiple smem window (multiple buffer)
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto
async_load_tile(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}
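// Typical usage sketch (illustrative, names are placeholders): the LDS store window is built
// on a CK_TILE_LDS_ADDR-qualified pointer, the global tile window supplies the per-thread
// coordinates, and a vmcnt fence plus a block barrier order the copy before LDS reads:
//   async_load_tile(lds_store_window, gmem_tile_window); // issue global->LDS copies
//   async_load_fence_raw();                              // wait for them to land
//   __builtin_amdgcn_s_barrier();                        // make them visible block-wide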
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
@@ -69,11 +89,6 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
...
@@ -187,4 +187,30 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
    });
}
namespace detail {
// check if two static_distributed_tensor types have the same data type and the same number
// of elements per thread, and differ only in their tile distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
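// Usage sketch (illustrative): for two tensors that share the data type and per-thread
// element count but carry different distribution encodings, the trait can gate e.g. a cheap
// register-to-register repack between them:
//   static_assert(ck_tile::detail::is_similiar_distributed_tensor_v<decltype(t_a), decltype(t_b)>, "");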
} // namespace ck_tile
@@ -104,6 +104,23 @@ struct tensor_view
                                               bool_constant<pre_nop>{});
    }
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
{
return buf_.template async_get<X>(
smem,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
    template <typename X,
              bool pre_nop = false,
              typename std::enable_if<
...
@@ -495,6 +495,74 @@ struct tile_window_with_static_distribution
        });
    }
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
});
}
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              bool_constant<oob_conditional_check> = {}) const
...
@@ -39,7 +39,7 @@ struct Default2DEpilogue
        if constexpr(kPadM || kPadN)
        {
            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
-            buffer_store_fence();
+            buffer_store_fence_raw();
        }
        else
        {
...
@@ -274,8 +274,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
                store_tile(lse_acc_dram_window_tmp,
                           tile_elementwise_in(lse_acc_element_func, lse_acc));
            }
-            buffer_load_fence(0);     // rocm-6.1, if whole tile is masked out, need to fence(0)
+            buffer_load_fence_raw(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
                                      // otherwise will have compute error(maybe compiler bug?)
            // Note: here o_acc is all cleared, return it
            return o_acc;
@@ -315,7 +315,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
        move_tile_window(k_dram_window, {0, kK0});
        __builtin_amdgcn_sched_barrier(0);
-        buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
+        buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
        (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
        // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
@@ -338,7 +338,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
            if constexpr(i_k0 < k0_loops - 1)
                move_tile_window(k_dram_window, {0, kK0});
-            async_load_fence(k_dram_window.get_num_access());
+            async_load_fence_raw(k_dram_window.get_num_access());
            __builtin_amdgcn_s_barrier();
            __builtin_amdgcn_sched_barrier(0);
            gemm_0(s_acc,
@@ -360,7 +360,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
            if constexpr(k0_loops <= 2)
                __builtin_amdgcn_sched_barrier(0);
-            async_load_fence();
+            async_load_fence_raw();
            __builtin_amdgcn_s_barrier();
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
...