Commit 3bbd5988 authored by Jing Zhang's avatar Jing Zhang
Browse files

adding fp16 mfma

parent 5c27dcd5
...@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -327,7 +327,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Number<NumBlks>{}, Number<BlkSize>{})); make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
StaticBuffer<AddressSpace::Vgpr, StaticBuffer<AddressSpace::Vgpr,
vector_type<FloatAB, c_blk_nb_bs_desc.GetElementSpaceSize()>, vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
c_mr_nr_nx_desc.GetElementSpaceSize()> c_mr_nr_nx_desc.GetElementSpaceSize()>
c_thread_buf; c_thread_buf;
...@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -488,7 +488,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{})); make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
StaticBuffer<AddressSpace::Vgpr, FloatAB, BlkSize> c_blk_buf_; StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) { static_for<0, NRepeat, 1>{}([&](auto nr_i) {
...@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -498,7 +498,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(mr_i, nr_i, xdlops_i))>{}]; make_tuple(mr_i, nr_i, xdlops_i))>{}];
static_for<0, BlkSize, 1>{}([&](auto j) { static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(j) = c_blk.template AsType<FloatAB>()[Number< c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
}); });
...@@ -518,7 +518,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -518,7 +518,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
CGlobalIteratorHacks{}; CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc, FloatC,
FloatC, FloatC,
decltype(c_m0_m1_m2_n_thread_desc), decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc), decltype(c_m0_m1_m2_n_global_desc),
......
...@@ -198,18 +198,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> ...@@ -198,18 +198,13 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t AStride, index_t COffset,
index_t BStride,
class FloatA, class FloatA,
class FloatB, class FloatB,
class FloatC> class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{ {
const auto p_a = reinterpret_cast<const half4_t*>(a); intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
const auto p_b = reinterpret_cast<const half4_t*>(b);
return intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
} }
}; };
...@@ -588,25 +583,13 @@ struct XdlopsGemm ...@@ -588,25 +583,13 @@ struct XdlopsGemm
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{}; return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
} }
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>() static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64, 1, 1, c_vec32_2_t>{}; return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 32>() static constexpr auto GetXdlopsInfo<half_t, 64, 32>()
{ {
......
...@@ -204,8 +204,8 @@ struct intrin_mfma_f32_32x32x1f32; ...@@ -204,8 +204,8 @@ struct intrin_mfma_f32_32x32x1f32;
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32( llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
...@@ -229,8 +229,8 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> ...@@ -229,8 +229,8 @@ struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32( llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
...@@ -249,8 +249,8 @@ struct intrin_mfma_f32_32x32x2f32; ...@@ -249,8 +249,8 @@ struct intrin_mfma_f32_32x32x2f32;
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x2f32( llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
...@@ -269,8 +269,8 @@ struct intrin_mfma_f32_16x16x4f32; ...@@ -269,8 +269,8 @@ struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f32( llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
...@@ -289,8 +289,8 @@ struct intrin_mfma_f32_16x16x1f32; ...@@ -289,8 +289,8 @@ struct intrin_mfma_f32_16x16x1f32;
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
...@@ -310,8 +310,8 @@ struct intrin_mfma_f32_4x4x1f32; ...@@ -310,8 +310,8 @@ struct intrin_mfma_f32_4x4x1f32;
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32( llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
...@@ -327,8 +327,8 @@ struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> ...@@ -327,8 +327,8 @@ struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
template <index_t COffset> template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{ {
template <class FloatA, class FloatB, class FloatC> template <class FloatC>
__device__ static void Run(const FloatA& reg_a, const FloatB& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) = reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32( llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
...@@ -349,78 +349,48 @@ struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> ...@@ -349,78 +349,48 @@ struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
} }
}; };
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride> template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x4f16; struct intrin_mfma_f32_32x32x4f16;
template <index_t AStride, index_t BStride> template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<128, 64, AStride, BStride> struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x4f16<64, 32, AStride, BStride>
{ {
__device__ static c_vec32_1_t::VecType template <class FloatC>
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
return reg_c; llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
} }
}; };
template <index_t AStride, index_t BStride> template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<32, 64, AStride, BStride> struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
{ {
__device__ static c_vec32_1_t::VecType template <class FloatC>
run(const half4_t* reg_a, const half4_t* reg_b, c_vec32_1_t::VecType reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
return reg_c; llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
} }
}; };
......
...@@ -110,11 +110,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -110,11 +110,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 8; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4; constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 8; constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
......
...@@ -651,6 +651,11 @@ int main(int argc, char* argv[]) ...@@ -651,6 +651,11 @@ int main(int argc, char* argv[])
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1
using in_data_t = half_t;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0 #elif 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
...@@ -819,6 +824,7 @@ int main(int argc, char* argv[]) ...@@ -819,6 +824,7 @@ int main(int argc, char* argv[])
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#if 0
if(do_log) if(do_log)
{ {
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
...@@ -826,5 +832,6 @@ int main(int argc, char* argv[]) ...@@ -826,5 +832,6 @@ int main(int argc, char* argv[])
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
} }
#endif
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment