Commit 3bbd5988 authored by Jing Zhang's avatar Jing Zhang
Browse files

adding fp16 mfma

parent 5c27dcd5
......@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAB, c_blk_nb_bs_desc.GetElementSpaceSize()>,
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf;
......@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatAB, BlkSize> c_blk_buf_;
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
......@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAB>()[Number<
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
});
......@@ -518,7 +518,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
......
......@@ -198,18 +198,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const half4_t*>(a);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
......@@ -588,25 +583,13 @@ struct XdlopsGemm
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1, c_vec32_2_t>{};
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
{
......
......@@ -204,8 +204,8 @@ struct intrin_mfma_f32_32x32x1f32;
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
......@@ -229,8 +229,8 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
......@@ -249,8 +249,8 @@ struct intrin_mfma_f32_32x32x2f32;
template <index_t COffset>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
......@@ -269,8 +269,8 @@ struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
......@@ -289,8 +289,8 @@ struct intrin_mfma_f32_16x16x1f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
......@@ -310,8 +310,8 @@ struct intrin_mfma_f32_4x4x1f32;
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
......@@ -327,8 +327,8 @@ struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{
template <class FloatA, class FloatB, class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c)
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
......@@ -349,78 +349,48 @@ struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
}
};
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x4f16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 32, AStride, BStride>
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
{
__device__ static c_vec32_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<32, 64, AStride, BStride>
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
{
__device__ static c_vec32_1_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c)
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
}
};
......
......@@ -110,11 +110,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
......
......@@ -78,8 +78,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
......@@ -651,6 +651,11 @@ int main(int argc, char* argv[])
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = float;
#elif 1
using in_data_t = half_t;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0
using in_data_t = float;
constexpr index_t in_vector_size = 1;
......@@ -819,6 +824,7 @@ int main(int argc, char* argv[])
check_error(out_nkhw_host, out_nkhw_device);
#if 0
if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
......@@ -826,5 +832,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
}
#endif
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment