Commit f09dc1f3 authored by carlushuang

compiler ok

parent 3bb718ad
......@@ -11,5 +11,7 @@ set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let the source compile without explicitly declared function specializations
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
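# NOTE: -v/--save-temps are debug aids: they print the compiler invocations and keep intermediate
# files (.i/.s); -Wno-gnu-line-marker silences warnings when those preprocessed temps are recompiled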
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS})
......@@ -62,6 +62,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
......
......@@ -888,6 +888,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
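// wait until at most `cnt` vector-memory operations remain outstanding; used as a fence for raw async buffer loads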
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
public:
template <typename Unique>
static constexpr index_t next()
{
return next<Unique>(0) * Step + Start;
}
template <unsigned long long>
static constexpr index_t next()
{
struct Unique
{
};
return next<Unique>(0) * Step + Start;
}
template <typename Unique>
static constexpr index_t current()
{
return current<Unique>(0) * Step + Start;
}
template <unsigned long long>
static constexpr index_t current()
{
struct Unique
{
};
return current<Unique>(0) * Step + Start;
}
private:
template <index_t I>
struct slot
{
_Pragma("GCC diagnostic push");
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
friend constexpr bool slot_allocated(slot<I>);
_Pragma("GCC diagnostic pop");
};
template <index_t I>
struct allocate_slot
{
friend constexpr bool slot_allocated(slot<I>) { return true; }
enum
{
value = I
};
};
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
static constexpr index_t next(index_t)
{
return next<Unique, I + 1>(0);
}
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
// allocate_slot<I>.
template <typename Unique, index_t I = 0>
static constexpr index_t next(double)
{
return allocate_slot<I>::value;
}
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
static constexpr index_t current(index_t)
{
return current<Unique, I + 1>(0);
}
// ...And this function will be used, instead, which will return the current counter, or assert
// in case next() hasn't been called yet.
template <typename Unique, index_t I = Start>
static constexpr index_t current(double)
{
static_assert(I != 0, "You must invoke next() first");
return I - 1;
}
};
namespace impl {
template <int I>
struct static_counter_uniq_;
}
#define MAKE_SC() \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
// Usage:
// constexpr auto c = MAKE_SC();
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
} // namespace ck_tile
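The counter works via friend injection (stateful metaprogramming): probing slot<I> with SFINAE walks to the first unallocated index, and instantiating allocate_slot<I> defines slot_allocated(slot<I>) so all later probes see that slot as taken. A minimal standalone sketch of the macros in use (hypothetical test file, assuming C++17 and the ck_tile headers on the include path; not part of this commit):

// static_counter_example.cpp -- hypothetical usage sketch
#include "ck_tile/core/utility/static_counter.hpp"

int main()
{
    // Start = 4, Step = 2: values are Start + n * Step
    constexpr auto c = MAKE_SC_WITH(4, 2);
    constexpr ck_tile::index_t v0 = NEXT_SC(c); // first call  -> 4
    constexpr ck_tile::index_t v1 = NEXT_SC(c); // second call -> 6
    static_assert(v0 == 4 && v1 == 6, "each NEXT_SC advances the counter by Step");
    return 0;
}

Each NEXT_SC expansion consumes a fresh __COUNTER__ value, so two calls on the same line still instantiate distinct next<...>() specializations and therefore allocate distinct slots.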
......@@ -156,10 +156,10 @@ struct FusedMoeGemmPipeline_Flatmm
using g_thread_type = decltype(load_tile(g_win));
using d_thread_type = decltype(load_tile(d_win));
// using WarpGemm0 = Policy::template GetWarpGemm0<Problem>();
// using WarpGemm1 = Policy::template GetWarpGemm1<Problem>();
// auto warp_gemm_0 = WarpGemm0{};
// auto warp_gemm_1 = WarpGemm1{};
using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
// issues_warps_lanes
auto a_sst_win0 =
......@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_Flatmm
{0, 0, 0});
// m*k
auto a_sld_win0 = [&]() {
using WG = decltype(Policy::template GetWarpGemm0<Problem>());
using WG = WarpGemm0;
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
......@@ -196,7 +196,7 @@ struct FusedMoeGemmPipeline_Flatmm
// m*k
auto a_sld_win1 = [&]() {
using WG = decltype(Policy::template GetWarpGemm0<Problem>());
using WG = WarpGemm0;
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
......@@ -242,10 +242,12 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr auto issues_d = number<d_win.get_num_of_access()>{};
constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr auto issues_gemm0 =
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0>{};
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
warp_gemm_0.get_num_of_access()>{};
constexpr auto issues_gemm1 =
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1>{};
constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
warp_gemm_1.get_num_of_access()>{};
// constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const index_t num_blocks_k0 =
(hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
......@@ -284,10 +286,8 @@ struct FusedMoeGemmPipeline_Flatmm
}
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
};
auto move_g =
[&]() {
move_tile_window(g_win,
{number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
auto move_g = [&]() {
move_tile_window(g_win, {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
};
statically_indexed_array<d_thread_type, 2> ds;
......@@ -314,16 +314,17 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off
auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
auto warp_gemm = Policy::template GetWarpGemm0<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
constexpr auto repeat_m = BlockShape::Repeat_M0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_k = i_access % repeat_k;
constexpr auto i_m = (i_access / repeat_k) % repeat_m;
constexpr auto i_n = (i_access / repeat_k) / repeat_m;
constexpr auto i_sub = i_access % repeat_sub;
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
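// e.g. with repeat_sub=2, repeat_k=4, repeat_m=2: i_access=13 -> i_sub=1, i_k=2, i_m=1, i_n=0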
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
......@@ -355,7 +356,7 @@ struct FusedMoeGemmPipeline_Flatmm
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
WarpGemm{}(w_c, w_a, w_b, PostNop{});
warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
t_c.set_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
......@@ -367,16 +368,17 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off
auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
auto warp_gemm = Policy::template GetWarpGemm1<Problem>();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
constexpr auto repeat_m = BlockShape::Repeat_M1;
// constexpr auto repeat_n = BlockShape::Repeat_N1;
constexpr auto repeat_k = BlockShape::Repeat_K1;
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
constexpr auto repeat_m = BlockShape::Repeat_M0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_k = i_access % repeat_k;
constexpr auto i_m = (i_access / repeat_k) % repeat_m;
constexpr auto i_n = (i_access / repeat_k) / repeat_m;
constexpr auto i_sub = i_access % repeat_sub;
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
......@@ -408,7 +410,7 @@ struct FusedMoeGemmPipeline_Flatmm
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
WarpGemm{}(w_c, w_a, w_b, PostNop{});
warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
t_c.set_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
......@@ -416,7 +418,7 @@ struct FusedMoeGemmPipeline_Flatmm
w_c.get_thread_buffer());
};
// clang-format on
_Pragma("clang diagnostic pop")
_Pragma("clang diagnostic pop");
// this gemm pipeline is designed with the assumption that issues of buffer-load/ds_read can
// be hidden under mfma. In other words, mfma issues >= memory issues; this is true if we
......@@ -427,61 +429,49 @@ struct FusedMoeGemmPipeline_Flatmm
// mfma (which can reuse the B matrix) is only affected by the M repeat.
auto pipeline_gemm0 = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_ld = total_loops / (issues_g + issues_a + issues_sld_a);
// compute buffer 0
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr index_t SLD_A =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
constexpr index_t GLD_A =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
constexpr index_t GLD_B =
static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr auto c_sld_a_0 = MAKE_SC();
constexpr auto c_gld_a_0 = MAKE_SC();
constexpr auto c_gld_b_0 = MAKE_SC();
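// NEXT_SCI folds the static_for index into __COUNTER__: a macro expands __COUNTER__ only once
// per textual occurrence, so adding i_issue gives each unrolled iteration its own slot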
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
if constexpr(i_issue % mfma_per_ld == 0)
{
constexpr index_t ld_id = 0;
if constexpr(ld_id < issues_g)
{
gld_g(gs[I0], number<ld_id>{});
}
if constexpr(ld_id - issues_g < +issues_a)
{
gld_a(a_sst_win0, number<ld_id - issues_g>{});
}
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
{
sld_a(as[I1], a_sld_win1, number<ld_id - issues_g - issues_a>{});
}
ld_id++;
}
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & SLD_A)
sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
if constexpr(slot & GLD_A)
gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
if constexpr(slot & GLD_B)
gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
constexpr auto c_sld_a_1 = MAKE_SC();
constexpr auto c_gld_a_1 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue);
if constexpr(i_issue % mfma_per_ld == 0)
{
constexpr index_t ld_id = 0;
if constexpr(ld_id < issues_g)
{
gld_g(gs[I1], number<ld_id>{});
}
if constexpr(ld_id - issues_g < +issues_a)
{
gld_a(a_sst_win1, number<ld_id - issues_g>{});
}
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
{
sld_a(as[I0], a_sld_win0, number<ld_id - issues_g - issues_a>{});
}
ld_id++;
}
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & SLD_A)
sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
if constexpr(slot & GLD_A)
gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
if constexpr(slot & GLD_B)
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
});
move_g();
move_a();
......@@ -493,7 +483,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr index_t total_loops = issues_gemm0;
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0;
// constexpr index_t mfma_per_gld_a = total_loops / issues_a;
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
// constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
......@@ -515,7 +505,7 @@ struct FusedMoeGemmPipeline_Flatmm
});
// if mfma cycles exceed gld_a cycles, sync here
block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1{});
sld_a(as[I1], a_sld_win1, NEG1);
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
......
......@@ -609,11 +609,45 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
{
// this function returns a seq<...> used to schedule gld/sld/valu... within the mfma sequence
// the purpose is to hide those instructions under mfma
// every value inside seq<...> is a bit mask indicating a specific operation
using S_ = typename Problem::BlockShape;
constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// 64 slots in total: 32x buffer-load-dwordx4 (gld_b), 8x buffer-load-dwordx1-async
// (gld_a), 8x ds_read_b128 (sld_a); the remaining slots are left empty :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vva;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_vav;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
......
......@@ -33,4 +33,15 @@ struct FusedMoeGemmTraits
static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
};
// Note: this needs to be a bit mask
enum class FusedMoeGemmPipelineSequencerEnum
{
SLD_A = 1 << 0, // shared load a
SLD_B = 1 << 1,
GLD_A = 1 << 2, // global load a
GLD_B = 1 << 3,
SST_A = 1 << 4, // shared store a
SST_B = 1 << 5,
};
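// a slot value may OR several ops together; the pipeline tests each bit, e.g.
// if constexpr(slot & static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A)),
// so a single mfma slot can issue more than one kind of operation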
} // namespace ck_tile
......@@ -25,6 +25,8 @@ struct WarpGemmAtrributeMfma
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
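// number of mfma issues one warp-gemm call expands to; the pipeline sequencer uses this
// to split a call into per-access steps (see repeat_sub in the flatmm pipeline)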
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
......@@ -88,6 +90,8 @@ struct WarpGemmAtrributeMfmaIterateK
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
......@@ -197,6 +201,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
......@@ -258,6 +264,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
......@@ -326,6 +334,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
......@@ -439,6 +449,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // how many CM1 are grouped together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
......@@ -576,6 +588,8 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // how many CM1 are grouped together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
......
......@@ -24,7 +24,7 @@ enum class WGAttrCtlEnum
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
......
......@@ -31,6 +31,11 @@ struct WarpGemmImpl
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
{
return WarpGemmAttribute_::get_num_of_access();
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
......