Commit 63bad606 authored by Jing Zhang

demo of removing array for A/B in xdlops

parent 494608ce
@@ -123,12 +123,9 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
               class FloatA,
               class FloatB,
               class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
+    __device__ FloatC run(const FloatA a, const FloatB b, FloatC reg_c) const
     {
-        const auto p_a = reinterpret_cast<const float*>(a);
-        const auto p_b = reinterpret_cast<const float*>(b);
-        return intrin_mfma_f32_16x16x4f32(p_a, p_b, reg_c);
+        return intrin_mfma_f32_16x16x4f32(a, b, reg_c);
     }
 };
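Note on the hunk above: run() now takes the A/B operands by value instead of through const FloatA*/FloatB* pointers, so the reinterpret_cast through memory disappears and the operands can stay in VGPRs. As a hedged illustration only (this is not the repository's wrapper), the sketch below calls the gfx908 matrix-core compiler builtin, which already expects scalar A/B operands by value; it assumes HIP/Clang with __builtin_amdgcn_mfma_f32_16x16x4f32 available when compiling for a matrix-core target such as gfx908, and float4_t is a local alias, not a name from this repository.

#include <hip/hip_runtime.h>

// Local four-wide float vector alias for the accumulator.
using float4_t = float __attribute__((ext_vector_type(4)));

// Illustrative sketch, not the repository's intrin_mfma_* wrapper:
// a by-value call into the 16x16x4 f32 matrix-core builtin.
// The cbsz / abid / blgp modifiers are left at 0 (plain accumulation).
__device__ float4_t mfma_16x16x4_f32_sketch(float a, float b, float4_t acc)
{
    return __builtin_amdgcn_mfma_f32_16x16x4f32(a, b, acc, 0, 0, 0);
}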
@@ -708,6 +705,12 @@ struct XdlopsGemm_t
     }
 #endif
 
+    template <class FloatAB>
+    __device__ static auto lds_load(const FloatAB* p_src, const index_t src_offset)
+    {
+        return p_src[src_offset];
+    }
+
     template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
     __device__ FloatC Run(const FloatA* const __restrict__ p_a_wave,
                           const FloatB* const __restrict__ p_b_wave,
@@ -727,6 +730,13 @@ struct XdlopsGemm_t
         FloatA a[K * MRepeats];
         FloatB b[K * NRepeats];
 
+        constexpr index_t data_size       = sizeof(FloatA) / sizeof(data_type);
+        constexpr index_t a_reg_buff_size = K * MRepeats * data_size;
+        constexpr index_t b_reg_buff_size = K * NRepeats * data_size;
+
+        auto reg_a = GetRegBuffer<data_type, a_reg_buff_size>();
+        auto reg_b = GetRegBuffer<data_type, b_reg_buff_size>();
+
         static_assert(sizeof(FloatA) % (sizeof(data_type) * mfma_type.k_base) == 0,
                       "wrong! FloatA is consistent with mfma");
@@ -769,25 +779,39 @@ struct XdlopsGemm_t
             const index_t blk_id = laneId / mfma_type.num_threads_blk;
             const index_t blk_td = laneId % mfma_type.num_threads_blk;
 
+#if 0
             // load into registers
             for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
             {
                 a[k_i] = p_a_wave[(k_i + blk_id) * M + blk_td];
                 b[k_i] = p_b_wave[(k_i + blk_id) * N + blk_td];
             }
-#if CK_WORKAROUND_SWDEV_229564
-#pragma unroll
 #endif
-            for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
-            {
-                for(index_t i = 0; i < KRepeats; ++i)
+
+            static_for<0, K, mfma_type.num_input_blks>{}([&](auto k_i) {
+                index_t a_offset = (k_i + blk_id) * M + blk_td;
+                reg_a.GetVector(Number<data_size>{})(Number<k_i>{}) =
+                    lds_load(p_a_wave, a_offset);
+                index_t b_offset = (k_i + blk_id) * N + blk_td;
+                reg_b.GetVector(Number<data_size>{})(Number<k_i>{}) =
+                    lds_load(p_b_wave, b_offset);
+            });
+
+            // for(index_t k_i = 0; k_i < K; k_i += mfma_type.num_input_blks)
+            //     for(index_t i = 0; i < KRepeats; ++i)
+            static_for<0, K, mfma_type.num_input_blks>{}([&](auto k_i) {
+                static_for<0, KRepeats, 1>{}([&](auto i) {
+                    constexpr index_t offset = k_i * KRepeats + i;
                     p_c_thread =
                         mfma_type.template run<MPerXdlops, NPerXdlops, AStride, BStride>(
-                            &pa[(k_i * KRepeats + i) * mfma_type.k_base],
-                            &pb[(k_i * KRepeats + i) * mfma_type.k_base],
+                            reg_a.GetVector(Number<mfma_type.k_base>{})[Number<offset>{}],
+                            reg_b.GetVector(Number<mfma_type.k_base>{})[Number<offset>{}],
                             p_c_thread);
-            }
+                });
+            });
         });
 #endif
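In the hunk above, the runtime loops over k_i and i are rewritten as static_for loops because the new register-buffer accesses need compile-time indices (Number<k_i>{}, Number<offset>{}). static_for is an existing composable_kernel utility; the snippet below is only a minimal sketch of that pattern in plain C++17, not the library's actual implementation, and Number is redefined locally to keep the sketch self-contained.

#include <cstddef>

template <std::size_t I>
struct Number { static constexpr std::size_t value = I; };

// Call f(Number<I>{}) for I = Begin, Begin + Inc, ... while I < End,
// so the loop index is a compile-time constant inside the body.
template <std::size_t Begin, std::size_t End, std::size_t Inc>
struct static_for
{
    template <class F>
    void operator()(F f) const
    {
        if constexpr(Begin < End)
        {
            f(Number<Begin>{});
            static_for<Begin + Inc, End, Inc>{}(f);
        }
    }
};

int main()
{
    int sum = 0;
    // visits 0, 2, 4, 6 as compile-time constants
    static_for<0, 8, 2>{}([&](auto i) { sum += decltype(i)::value; });
    return sum == 12 ? 0 : 1;
}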
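reg_a and reg_b replace the raw FloatA a[K * MRepeats] / FloatB b[K * NRepeats] arrays with statically sized register buffers that are only ever indexed by compile-time constants, which lets the compiler keep every element in a register rather than spilling an indexed array. GetRegBuffer and GetVector are not defined in this diff, so the following is a hypothetical sketch of the interface as inferred from its use here: RegBufferSketch and VectorView are made-up names, and the real GetVector presumably yields proper vector types rather than a reference to the leading scalar of each vector.

#include <cstddef>

template <std::size_t I>
struct Number { static constexpr std::size_t value = I; };

// Statically sized scalar buffer with "vector of VecLen scalars" views.
template <class T, std::size_t NScalars>
struct RegBufferSketch
{
    T data[NScalars];

    template <std::size_t VecLen>
    struct VectorView
    {
        T* p;
        // operator(): writable access, mirroring GetVector(...)(Number<>{}) = ...
        template <std::size_t I>
        T& operator()(Number<I>) { return p[VecLen * I]; }
        // operator[]: read access, mirroring GetVector(...)[Number<>{}]
        template <std::size_t I>
        const T& operator[](Number<I>) const { return p[VecLen * I]; }
    };

    template <std::size_t VecLen>
    VectorView<VecLen> GetVector(Number<VecLen>)
    {
        static_assert(NScalars % VecLen == 0, "buffer not divisible into VecLen-wide vectors");
        return VectorView<VecLen>{data};
    }
};

int main()
{
    RegBufferSketch<float, 8> reg_a{};
    // write the first scalar of the 1st two-wide vector, then read it back
    reg_a.GetVector(Number<2>{})(Number<1>{}) = 3.0f;
    float x = reg_a.GetVector(Number<2>{})[Number<1>{}];
    return x == 3.0f ? 0 : 1;
}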
@@ -132,12 +132,12 @@ intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::
     return reg_c;
 }
 
-__device__ float_vec4_t intrin_mfma_f32_16x16x4f32(const float* reg_a,
-                                                    const float* reg_b,
+__device__ float_vec4_t intrin_mfma_f32_16x16x4f32(const float reg_a,
+                                                    const float reg_b,
                                                     float_vec4_t reg_c)
 {
     reg_c.s4(Number<0>{}) =
-        llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s4[Number<0>{}], 0, 0, 0);
+        llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a, reg_b, reg_c.s4[Number<0>{}], 0, 0, 0);
     return reg_c;
 }