Commit e9f05865 authored by Jing Zhang's avatar Jing Zhang
Browse files

adjust xdlops

parent 82a15a27
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include "ConstantMatrixDescriptor.hpp" #include "ConstantMatrixDescriptor.hpp"
#include "math.hpp" #include "math.hpp"
#define CK_USE_AMD_XDLOPS_EMULATE 1
namespace ck { namespace ck {
enum struct mfma_instr enum struct mfma_instr
...@@ -839,56 +841,18 @@ struct XdlopsGemm_t ...@@ -839,56 +841,18 @@ struct XdlopsGemm_t
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
FloatA a[K];
FloatB b[K];
// load into registers
for(index_t k = 0; k < K; ++k)
{
a[k] = p_a_wave[k * M + laneId];
b[k] = p_b_wave[k * N + laneId];
}
// get pointer of registers
auto pa = reinterpret_cast<const data_type*>(&a);
auto pb = reinterpret_cast<const data_type*>(&b);
for(index_t k = 0; k < K; ++k) for(index_t k = 0; k < K; ++k)
{ {
constexpr index_t nxdlops = sizeof(FloatA) / (mfma_type.k * sizeof(data_type)); mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
for(index_t i = 0; i < nxdlops; ++i, pa += mfma_type.k, pb += mfma_type.k)
mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, pa, pb, p_c_thread);
} }
}).Else([&](auto) { }).Else([&](auto) {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
FloatA a[K];
FloatB b[K];
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
// load into registers
for(index_t k = 0; k < K; k += mfma_type.num_input_blks) for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{ {
a[k] = p_a_wave[(k + blk_id) * M + blk_td]; mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
b[k] = p_b_wave[(k + blk_id) * N + blk_td];
}
// get pointer of registers
auto pa = reinterpret_cast<const data_type*>(&a);
auto pb = reinterpret_cast<const data_type*>(&b);
constexpr index_t nxdlops =
(sizeof(FloatA) * mfma_type.num_input_blks) / (mfma_type.k * sizeof(data_type));
for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
{
for(index_t i = 0; i < nxdlops; ++i, pa += mfma_type.k, pb += mfma_type.k)
mfma_type.run(Number<MPerWave>{}, Number<NPerWave>{}, pa, pb, p_c_thread);
} }
}); });
...@@ -942,7 +906,7 @@ struct XdlopsGemm_t ...@@ -942,7 +906,7 @@ struct XdlopsGemm_t
__device__ void ReadXdlopsRegs(Number<Size>, FloatC* const __restrict__ p_c_thread) const __device__ void ReadXdlopsRegs(Number<Size>, FloatC* const __restrict__ p_c_thread) const
{ {
#if !CK_USE_AMD_XDLOPS_EMULATE #if !CK_USE_AMD_XDLOPS_EMULATE
//constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>(); constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
//gcnasm_nop<mfma_type.cycles>(); //gcnasm_nop<mfma_type.cycles>();
//gcnasm_accvgpr_read<Size>(p_c_thread); //gcnasm_accvgpr_read<Size>(p_c_thread);
#else #else
......
...@@ -12,8 +12,7 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const flo ...@@ -12,8 +12,7 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const flo
auto reg_c_ = reinterpret_cast<float_t*>(reg_c); auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
for(index_t i = 0; i < 32; i++) for(index_t i = 0; i < 32; i++)
{ {
reg_c_[i] += reg_a * reg_b; reg_c_[i + 32] = reg_c_[i] = reg_c_[i] + reg_a * reg_b;
reg_c_[i+32] = reg_c[i];
} }
} }
...@@ -24,7 +23,6 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const flo ...@@ -24,7 +23,6 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const flo
for(index_t i = 0; i < 16; i++) for(index_t i = 0; i < 16; i++)
{ {
reg_c_[i] += reg_a * reg_b; reg_c_[i] += reg_a * reg_b;
reg_c_[i+16] = reg_c[i];
} }
} }
...@@ -35,7 +33,6 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const flo ...@@ -35,7 +33,6 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const flo
for(index_t i = 0; i < 16; i++) for(index_t i = 0; i < 16; i++)
{ {
reg_c_[i] += reg_a * reg_b; reg_c_[i] += reg_a * reg_b;
reg_c_[i+16] = reg_c[i];
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment