Commit 4ea89209 authored by Jing Zhang's avatar Jing Zhang
Browse files

add 32x32x8fp16

parent 822856e1
......@@ -227,17 +227,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_32x32x8f16(p_a, p_b, reg_c);
intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -589,19 +585,18 @@ struct XdlopsGemm
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{};
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{};
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
{
......@@ -759,12 +754,14 @@ struct XdlopsGemm
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread);
p_a_wave[Number<a_offset / mfma_type.k_base>{}],
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread);
});
}
......
......@@ -394,12 +394,25 @@ struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
}
};
// Wrapper for the AMD MFMA instruction v_mfma_f32_32x32x8f16: one issue
// multiplies a 32x32 half-precision tile pair with K = 8 and accumulates
// into f32 registers. The primary template is declared only; each supported
// (MPerWave, NPerWave) shape gets an explicit specialization.
//
// NOTE(review): this span of the diff interleaved the old free-function form
// (taking half4_t* and a c_vec16_1_t::VecType by value) with its templated
// replacement; only the replacement is kept, matching the updated call site
// that invokes intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x8f16;

// 32x32-per-wave specialization.
template <index_t COffset>
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
{
    // Accumulate reg_a * reg_b into accumulator slot COffset of reg_c.
    //
    // reg_c is a compile-time-indexed vector of accumulators: slot COffset is
    // reinterpreted as float16_t (16 packed floats, the 32x32 f32 output tile
    // owned by this lane) and element 0 of that view is both the C input and
    // the destination of the intrinsic.
    // The three trailing zeros are the intrinsic's modifier operands
    // (presumably cbsz/abid/blgp — no broadcast/swizzle; verify against the
    // CDNA ISA reference).
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
            llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
                reg_a,
                reg_b,
                reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
                0,
                0,
                0);
    }
};
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
......
......@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment