Commit c811a0e9 authored by aska-0096's avatar aska-0096
Browse files

temp save, add asm backend flag to amd_wmma

parent c749c262
...@@ -53,13 +53,13 @@ using DeviceConvFwdInstance = ...@@ -53,13 +53,13 @@ using DeviceConvFwdInstance =
GemmSpec, // GemmSpecialization GemmSpec, // GemmSpecialization
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 256, // NPerBlock
4, // K0PerBlock 4, // K0PerBlock
8, // K1 8, // K1
16, // MPerWMMA 16, // MPerWMMA
16, // NPerWMMA 16, // NPerWMMA
4, // MRepeat 4, // MRepeat
2, // NRepeat 4, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......
...@@ -375,7 +375,9 @@ template <index_t BlockSize, ...@@ -375,7 +375,9 @@ template <index_t BlockSize,
index_t NPerWMMA, index_t NPerWMMA,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack,
bool TransposeC = false,
bool AssemblyBackend = true>
/* A: K0PerBlock x MPerBlock x K1 /* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1 * B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
...@@ -406,7 +408,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO ...@@ -406,7 +408,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{}; WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
......
...@@ -683,7 +683,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle ...@@ -683,7 +683,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
NPerWmma, NPerWmma,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>{}; KPack,
false,
true>{};
// Prepare Register for C matrix // Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
...@@ -103,12 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, ...@@ -103,12 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC> template <index_t MPerWmma, index_t NPerWmma, bool AssemblyBackend, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
if constexpr(wave_size == 32) if constexpr(wave_size == 32)
{ {
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c); intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma, AssemblyBackend>::Run(a, b, reg_c);
} }
else if constexpr(wave_size == 64) else if constexpr(wave_size == 64)
{ {
...@@ -358,7 +358,8 @@ template <typename src_type_a, ...@@ -358,7 +358,8 @@ template <typename src_type_a,
index_t MPerWmma, index_t MPerWmma,
index_t NPerWmma, index_t NPerWmma,
index_t KPack, index_t KPack,
bool TransposeC = false> bool TransposeC = false,
bool AssemblyBackend = false>
struct WmmaGemm struct WmmaGemm
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -491,11 +492,11 @@ struct WmmaGemm ...@@ -491,11 +492,11 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!"); "(int8, int32) or (int4, int32)!");
if constexpr(!TransposeC) if constexpr(!TransposeC)
{ {
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread); wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_a_wave, p_b_wave, p_c_thread);
} }
else else
{ {
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread); wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_b_wave, p_a_wave, p_c_thread);
} }
} }
......
...@@ -12,21 +12,23 @@ namespace ck { ...@@ -12,21 +12,23 @@ namespace ck {
/********************************WAVE32 MODE***********************************************/ /********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32 // src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave, bool AssemblyBackend>
struct intrin_wmma_f32_16x16x16_f16_w32; struct intrin_wmma_f32_16x16x16_f16_w32;
template <> template <bool AssemblyBackend>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
{ {
template <class FloatC> template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{ {
// * Inline assembly need to elimate the duplicated data load, compiler won't help you if constexpr(AssemblyBackend){
// delete them. amd_assembly_wmma_f32_16x16x16_f16_w32(
// amd_assembly_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{})); }
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( else{
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]); reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment