Commit c811a0e9 authored by aska-0096's avatar aska-0096
Browse files

temp save, add asm backend flag to amd_wmma

parent c749c262
......@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
256, // NPerBlock
4, // K0PerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
4, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......
......@@ -375,7 +375,9 @@ template <index_t BlockSize,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false,
bool AssemblyBackend = true>
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
......@@ -406,7 +408,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......
......@@ -683,7 +683,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
false,
true>{};
// Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
......@@ -103,12 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
template <index_t MPerWmma, index_t NPerWmma, bool AssemblyBackend, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
{
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma, AssemblyBackend>::Run(a, b, reg_c);
}
else if constexpr(wave_size == 64)
{
......@@ -358,7 +358,8 @@ template <typename src_type_a,
index_t MPerWmma,
index_t NPerWmma,
index_t KPack,
bool TransposeC = false>
bool TransposeC = false,
bool AssemblyBackend = false>
struct WmmaGemm
{
static constexpr auto I0 = Number<0>{};
......@@ -491,11 +492,11 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_b_wave, p_a_wave, p_c_thread);
}
}
......
......@@ -12,21 +12,23 @@ namespace ck {
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
template <index_t MPerWave, index_t NPerWave, bool AssemblyBackend>
struct intrin_wmma_f32_16x16x16_f16_w32;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
template <bool AssemblyBackend>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
if constexpr(AssemblyBackend){
amd_assembly_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
}
else{
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment