Commit e610402f authored by Jing Zhang
Browse files

add fp16 mfma

parent 4ea89209
...@@ -256,17 +256,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16> ...@@ -256,17 +256,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
// Side-by-side diff of run() under the mfma_info<mfma_instr::mfma_f32_16x16x16f16>
// hunk header; on each line the removed text comes first, the added text second.
// Removed: run() templated on AStride/BStride, taking const FloatA*/const FloatB*
//          and FloatC by value, casting both pointers to const half4_t* and
//          returning intrin_mfma_f32_16x16x16f16(p_a, p_b, reg_c).
// Added:   run() templated on COffset, taking a, b and reg_c by reference,
//          returning void and forwarding to
//          intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run.
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const half4_t*>(a); intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_16x16x16f16(p_a, p_b, reg_c);
} }
}; };
...@@ -289,17 +285,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16> ...@@ -289,17 +285,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
// Side-by-side diff of run() under the mfma_info<mfma_instr::mfma_f32_16x16x4f16>
// hunk header; on each line the removed text comes first, the added text second.
// Removed: pointer-based run() (AStride/BStride template parameters) that cast
//          a and b to const half4_t* and returned
//          intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c).
// Added:   reference-based run() (COffset template parameter) that returns void
//          and forwards to
//          intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run.
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const half4_t*>(a); intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
} }
}; };
...@@ -322,17 +314,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16> ...@@ -322,17 +314,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
// Side-by-side diff of run() under the mfma_info<mfma_instr::mfma_f32_4x4x4f16>
// hunk header; on each line the removed text comes first, the added text second.
// Removed: pointer-based run() (AStride/BStride template parameters) that cast
//          a and b to const half4_t* and returned
//          intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c).
// Added:   reference-based run() (COffset template parameter) that returns void
//          and forwards to
//          intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run.
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const half4_t*>(a); intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
} }
}; };
...@@ -596,43 +584,32 @@ struct XdlopsGemm ...@@ -596,43 +584,32 @@ struct XdlopsGemm
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
} }
#if 0
// Side-by-side diff of a GetXdlopsInfo specialization (removed text first, added second).
// Removed: GetXdlopsInfo<half_t, 64, 16> returning
//          xdlops_info<mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>.
// Added:   GetXdlopsInfo<half_t, 16, 16> returning
//          xdlops_info<mfma_f32_16x16x16f16, 16, 16> (no trailing 1, 1, c_vec* arguments).
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>() static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
} }
// Side-by-side diff of GetXdlopsInfo<half_t, 16, 64> (removed text first, added second).
// Both columns select mfma_instr::mfma_f32_16x16x4f16 for a 16x64 tile; the added
// version drops the trailing "1, 1, c_vec16_1_t" template arguments of xdlops_info.
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>() static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1, c_vec16_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
} }
// Side-by-side diff of GetXdlopsInfo<half_t, 8, 64> (removed text first, added second).
// Both columns select mfma_instr::mfma_f32_4x4x4f16 for an 8x64 tile; the added
// version drops the trailing "1, 1, c_vec4_2_t" template arguments of xdlops_info.
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>() static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1, c_vec4_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
} }
// Side-by-side diff of GetXdlopsInfo<half_t, 4, 64> (removed text first, added second).
// Both columns select mfma_instr::mfma_f32_4x4x4f16 for a 4x64 tile; the added
// version drops the trailing "1, 1, c_vec4_1_t" template arguments of xdlops_info.
// The single-column lines that follow appear to be removed-only text: the old
// closing brace plus the old <half_t, 32, 32> and <half_t, 16, 16> specializations.
// The final "} }" line closes the old 16x16 specialization (left) and the new
// 4x64 specialization (right).
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>() static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{};
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>() static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{ {
......
...@@ -414,60 +414,88 @@ struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> ...@@ -414,60 +414,88 @@ struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
} }
}; };
// Side-by-side diff (removed text first, added text second on fused lines).
// Removed column: the free function intrin_mfma_f32_16x16x16f16 (pointer operands,
//   accumulating into reg_c.s.x with trailing arguments 0, 0, 0), the primary
//   template declaration of intrin_mfma_f32_16x16x4f16<MPerWave, NPerWave>, and its
//   <16, 64> specialization (trailing arguments 2, 0, 0).
// Added column: a forward declaration of the class template
//   intrin_mfma_f32_16x16x16f16<MPerWave, NPerWave, COffset> and its
//   <16, 16, COffset> partial specialization. Its static Run() reads the float4_t
//   view of reg_c at index COffset, passes it with reg_a and reg_b to
//   llvm_intrin_amdgcn_mfma_f32_16x16x16f16 and writes the result back to that slot.
// NOTE(review): the three trailing integer arguments are presumably the MFMA
//   cbsz/abid/blgp modifiers — confirm against the intrinsic's declaration.
__device__ c_vec4_1_t::VecType template <index_t MPerWave, index_t NPerWave, index_t COffset>
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c) struct intrin_mfma_f32_16x16x16f16;
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x4f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c);
template <> template <index_t COffset>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<16, 64>(const half4_t* reg_a, struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); template <class FloatC>
return reg_c; __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
} {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
// Side-by-side diff (removed text first, added text second on fused lines).
// Removed column: the <64, 16> specialization of the free function
//   intrin_mfma_f32_16x16x4f16 (pointer operands, accumulating into reg_c.s.x with
//   trailing arguments 0, 0, 4). No <64, 16> replacement is visible in the added column.
// Added column: a forward declaration of the class template
//   intrin_mfma_f32_16x16x4f16<MPerWave, NPerWave, COffset> and its
//   <16, 64, COffset> partial specialization. Its static Run() reads the float16_t
//   view of reg_c at index COffset, passes it with reg_a and reg_b to
//   llvm_intrin_amdgcn_mfma_f32_16x16x4f16 (trailing arguments 2, 0, 0) and writes
//   the result back to that slot.
template <> template <index_t MPerWave, index_t NPerWave, index_t COffset>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<64, 16>(const half4_t* reg_a, struct intrin_mfma_f32_16x16x4f16;
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); template <class FloatC>
return reg_c; __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
} {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x4f16; struct intrin_mfma_f32_4x4x4f16;
// Side-by-side diff of the 4x64 specialization of intrin_mfma_f32_4x4x4f16
// (removed text first, added text second on fused lines).
// Removed: full specialization <4, 64> whose static run() took half4_t pointers and
//          a c_vec4_1_t::VecType by value, updated reg_c.s.x and returned reg_c.
// Added:   partial specialization <4, 64, COffset> whose static Run() takes
//          references and updates, in place, the float4_t view of reg_c at index
//          COffset via one llvm_intrin_amdgcn_mfma_f32_4x4x4f16 call.
// Both columns pass the same trailing arguments 4, 0, 0.
template <> template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<4, 64> struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
{ {
__device__ static c_vec4_1_t::VecType template <class FloatC>
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
return reg_c; llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
} }
}; };
// Side-by-side diff of the 8x64 specialization of intrin_mfma_f32_4x4x4f16
// (removed text first, added text second on fused lines).
// Removed: full specialization <8, 64> whose static run() took half4_t pointers and
//          a c_vec4_2_t::VecType by value, updated reg_c.s.x and reg_c.s.y with two
//          intrinsic calls and returned reg_c.
// Added:   partial specialization <8, 64, COffset> whose static Run() takes
//          references and issues two llvm_intrin_amdgcn_mfma_f32_4x4x4f16 calls,
//          updating the float4_t views of reg_c at indices COffset and COffset + 1.
// In both columns the first call passes trailing arguments 4, 0, 0 and the second
// passes 4, 1, 0.
template <> template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<8, 64> struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
{ {
__device__ static c_vec4_2_t::VecType template <class FloatC>
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_2_t::VecType reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
return reg_c; reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
} }
}; };
......
...@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerWave = 4;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8; constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 16;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment