Commit 4ea89209 authored by Jing Zhang's avatar Jing Zhang
Browse files

add 32x32x8fp16

parent 822856e1
...@@ -227,17 +227,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16> ...@@ -227,17 +227,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const half4_t*>(a); intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_32x32x8f16(p_a, p_b, reg_c);
} }
}; };
...@@ -589,19 +585,18 @@ struct XdlopsGemm ...@@ -589,19 +585,18 @@ struct XdlopsGemm
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>() static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 32, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>() static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64, 1, 1, c_vec32_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>() static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
{ {
...@@ -759,12 +754,14 @@ struct XdlopsGemm ...@@ -759,12 +754,14 @@ struct XdlopsGemm
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) { static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>( mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], p_c_thread); p_a_wave[Number<a_offset / mfma_type.k_base>{}],
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread);
}); });
} }
......
...@@ -394,12 +394,25 @@ struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> ...@@ -394,12 +394,25 @@ struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
} }
}; };
__device__ c_vec16_1_t::VecType template <index_t MPerWave, index_t NPerWave, index_t COffset>
intrin_mfma_f32_32x32x8f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c) struct intrin_mfma_f32_32x32x8f16;
template <index_t COffset>
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); template <class FloatC>
return reg_c; __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
} {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
}
};
__device__ c_vec4_1_t::VecType __device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c) intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
......
...@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPack = 4; constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment