Commit 5c27dcd5 authored by Jing Zhang
Browse files

add fp32 mfma instructions

parent 21755b5d
...@@ -11,7 +11,8 @@ namespace ck { ...@@ -11,7 +11,8 @@ namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmN = N * Ho * Wo
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t GemmMPerBlock, template <typename FloatAB,
index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmMPerWave, index_t GemmMPerWave,
index_t GemmNPerWave, index_t GemmNPerWave,
...@@ -109,7 +110,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -109,7 +110,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<float, GemmMPerWave, GemmNPerWave, GemmKPerWave>{}; constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout(); constexpr auto CLayout = xdlops_gemm.GetCLayout();
......
...@@ -34,7 +34,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -34,7 +34,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerWave, NPerWave, KPack>{};
static constexpr index_t MWaves = M1 / MPerWave; static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave; static constexpr index_t NWaves = N1 / NPerWave;
......
...@@ -306,14 +306,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -306,14 +306,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{})); make_tuple(Sequence<0, 3>{}, Sequence<1, 2>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline<BlockSize, BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m0_m1_k1_block_desc), decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc), decltype(b_k0_n0_n1_k1_block_desc),
MPerWave, MPerWave,
NPerWave, NPerWave,
KPack>{}; KPack>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout(); constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize(); constexpr index_t BlkSize = CLayout.GetBlkSize();
...@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Number<NumBlks>{}, Number<BlkSize>{})); make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr, StaticBuffer<AddressSpace::Vgpr,
vector_type<float, c_blk_nb_bs_desc.GetElementSpaceSize()>, vector_type<FloatAB, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()> c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf; c_thread_buf;
...@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{})); make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_blk_buf_; StaticBuffer<AddressSpace::Vgpr, FloatAB, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) { static_for<0, NRepeat, 1>{}([&](auto nr_i) {
...@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(mr_i, nr_i, xdlops_i))>{}]; make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) { static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<float>()[Number< c_blk_buf_(j) = c_blk.template AsType<FloatAB>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
}); });
......
...@@ -10,19 +10,19 @@ namespace ck { ...@@ -10,19 +10,19 @@ namespace ck {
enum struct mfma_instr enum struct mfma_instr
{ {
// fp32 /// fp32
mfma_f32_32x32x1xf32 = 0, mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32, mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32, mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction
// fp16 /// fp16
mfma_f32_32x32x4f16, mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16, mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16, mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction mfma_f32_16x16x16f16, // k reduction
// bfp16 /// bfp16
mfma_f32_32x32x2bf16, mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16, mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16, mfma_f32_4x4x2bf16,
...@@ -58,7 +58,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32> ...@@ -58,7 +58,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
class FloatC> class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c); intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
} }
}; };
...@@ -87,7 +87,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32> ...@@ -87,7 +87,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
class FloatC> class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c); intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
} }
}; };
...@@ -110,17 +110,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32> ...@@ -110,17 +110,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const float*>(a); intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_16x16x4f32(p_a, p_b, reg_c);
} }
}; };
...@@ -143,17 +139,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32> ...@@ -143,17 +139,13 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const float*>(a); intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
} }
}; };
...@@ -177,17 +169,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32> ...@@ -177,17 +169,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const float*>(a); intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
} }
}; };
...@@ -523,20 +511,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> ...@@ -523,20 +511,13 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
} }
}; };
template <mfma_instr instr, template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
index_t MPerXdlops_,
index_t NPerXdlops_,
index_t MRepeats_,
index_t NRepeats_,
class CType_>
struct xdlops_info struct xdlops_info
{ {
static constexpr auto mfma_type = mfma_info<instr>{}; static constexpr auto mfma_type = mfma_info<instr>{};
static constexpr index_t MPerXdlops = MPerXdlops_; static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_; static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr index_t MRepeats = MRepeats_;
static constexpr index_t NRepeats = NRepeats_;
static constexpr bool IsABroadcast() static constexpr bool IsABroadcast()
{ {
...@@ -555,8 +536,6 @@ struct xdlops_info ...@@ -555,8 +536,6 @@ struct xdlops_info
} }
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
static constexpr auto GetCType() { return CType_{}; }
}; };
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack> template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
...@@ -570,55 +549,43 @@ struct XdlopsGemm ...@@ -570,55 +549,43 @@ struct XdlopsGemm
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>() static constexpr auto GetXdlopsInfo<float, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64, 1, 1, float64_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 32, 1, 1, float32_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>() static constexpr auto GetXdlopsInfo<float, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64, 1, 1, float32_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 64, 16, 1, 1, float16_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>() static constexpr auto GetXdlopsInfo<float, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64, 1, 1, float16_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>() static constexpr auto GetXdlopsInfo<float, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64, 1, 1, float8_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>() static constexpr auto GetXdlopsInfo<float, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64, 1, 1, float4_t>{}; return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>() static constexpr auto GetXdlopsInfo<float, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32, 1, 1, float16_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>() static constexpr auto GetXdlopsInfo<float, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16, 1, 1, float4_t>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
} }
#if 0 #if 0
......
...@@ -201,42 +201,6 @@ extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( ...@@ -201,42 +201,6 @@ extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x1f32; struct intrin_mfma_f32_32x32x1f32;
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
//{
//__device__ static c_vec32_4_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
// reg_c.s.z =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
// reg_c.s.w =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
// return reg_c;
//}
//};
// template <index_t AStride, index_t BStride>
// struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
//{
//__device__ static c_vec32_4_t::VecType
// run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
// reg_c.s.z =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
// reg_c.s.w =
// llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
// return reg_c;
//}
//};
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{ {
...@@ -262,27 +226,22 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> ...@@ -262,27 +226,22 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
} }
}; };
// template <index_t AStride, index_t BStride> template <index_t COffset>
// struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride> struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
//{ {
//__device__ static c_vec32_1_t::VecType template <class FloatA, class FloatB, class FloatC>
// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c) __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
//{ {
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
// return reg_c; llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
//} reg_a,
//}; reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
// template <index_t AStride, index_t BStride> 1,
// struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride> 0,
//{ 0);
//__device__ static c_vec32_1_t::VecType }
// run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c) };
//{
// reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
// return reg_c;
//}
//};
template <index_t MPerWave, index_t NPerWave, index_t COffset> template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x2f32; struct intrin_mfma_f32_32x32x2f32;
...@@ -304,58 +263,89 @@ struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> ...@@ -304,58 +263,89 @@ struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
} }
}; };
__device__ c_vec4_1_t::VecType template <index_t MPerWave, index_t NPerWave, index_t COffset>
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c) struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); template <class FloatA, class FloatB, class FloatC>
return reg_c; __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
} {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave, index_t COffset>
__device__ c_vec16_1_t::VecType struct intrin_mfma_f32_16x16x1f32;
intrin_mfma_f32_16x16x1f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c);
template <> template <index_t COffset>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<16, 64>(const float* reg_a, struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
const float* reg_b,
c_vec16_1_t::VecType reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); template <class FloatA, class FloatB, class FloatC>
return reg_c; __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
} {
template <> reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x1f32<64, 16>(const float* reg_a, llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
const float* reg_b, reg_a,
c_vec16_1_t::VecType reg_c) reg_b,
{ reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); 2,
return reg_c; 0,
} 0);
}
};
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x1f32; struct intrin_mfma_f32_4x4x1f32;
template <> template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64> struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{ {
__device__ static c_vec4_1_t::VecType template <class FloatA, class FloatB, class FloatC>
run(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c) __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
return reg_c; llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
} }
}; };
template <> template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64> struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{ {
__device__ static c_vec4_2_t::VecType template <class FloatA, class FloatB, class FloatC>
run(const float* reg_a, const float* reg_b, c_vec4_2_t::VecType reg_c) __device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
return reg_c; reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
} }
}; };
......
...@@ -49,6 +49,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -49,6 +49,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
static_assert(1 == InWeiVectorSize, "support InWeiVectorSize == 1 only!");
#if 1 #if 1
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
...@@ -108,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -108,12 +110,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 32; constexpr index_t GemmMPerWave = 8;
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4; constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 8;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
...@@ -131,7 +133,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -131,7 +133,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#endif #endif
const auto descs = const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock, transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<TInWei,
GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmMPerWave, GemmMPerWave,
GemmNPerWave, GemmNPerWave,
...@@ -148,7 +151,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -148,7 +151,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
{ {
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1< float ave_time = launch_kernel_dynamic_gemm_xdlops_v1<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
...@@ -188,10 +191,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -188,10 +191,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
decltype(descs[I5]), decltype(descs[I5]),
decltype(descs[I6]), decltype(descs[I6]),
decltype(descs[I7]), decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0], descs[I0],
descs[I1], descs[I1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment