"example/01_gemm/gemm_xdl_f8.cpp" did not exist on "acbd7bd7c5efd17b7061157a5868e28acc04d33e"
Commit e610402f authored by Jing Zhang's avatar Jing Zhang
Browse files

add fp16 mfma

parent 4ea89209
......@@ -256,17 +256,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_16x16x16f16(p_a, p_b, reg_c);
intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -289,17 +285,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -322,17 +314,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -596,43 +584,32 @@ struct XdlopsGemm
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 16>()
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 64, 16, 1, 1, c_vec16_1_t>{};
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64, 1, 1, c_vec16_1_t>{};
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64, 1, 1, c_vec4_2_t>{};
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16, 1, 1, c_vec4_1_t>{};
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{
......
......@@ -414,60 +414,88 @@ struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
}
};
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x16f16(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_16x16x4f16(const half4_t* reg_a, const half4_t* reg_b, c_vec16_1_t::VecType reg_c);
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x16f16;
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<16, 64>(const half4_t* reg_a,
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
template <index_t COffset>
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x4f16<64, 16>(const half4_t* reg_a,
const half4_t* reg_b,
c_vec16_1_t::VecType reg_c)
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f16;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave>
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x4f16;
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
{
__device__ static c_vec4_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_1_t::VecType reg_c)
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
}
};
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
{
__device__ static c_vec4_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec4_2_t::VecType reg_c)
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
}
};
......
......@@ -110,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPack = 8;
constexpr index_t GemmMPerWave = 4;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
constexpr index_t MRepeat = 16;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment