Commit f5ae909b authored by Jing Zhang's avatar Jing Zhang
Browse files

add bfl16 buildins

parent 41fb383f
...@@ -5,45 +5,6 @@ ...@@ -5,45 +5,6 @@
namespace ck { namespace ck {
// A, B, C, cbsz, abid, blgp
// bfp16
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
// int8
extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
// fp32 // fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32; struct intrin_mfma_f32_32x32x1f32;
...@@ -248,9 +209,8 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> ...@@ -248,9 +209,8 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -263,9 +223,8 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> ...@@ -263,9 +223,8 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -278,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32> ...@@ -278,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -292,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> ...@@ -292,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" //#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" //#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
...@@ -250,15 +250,15 @@ int main(int argc, char* argv[]) ...@@ -250,15 +250,15 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif #endif
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = float; using out_data_t = float;
#elif 1 #elif 0
using in_data_t = half_t; using in_data_t = half_t;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = half_t; using out_data_t = half_t;
#elif 0 #elif 1
using in_data_t = ushort; using in_data_t = ushort;
using acc_data_t = float; using acc_data_t = float;
using out_data_t = ushort; using out_data_t = ushort;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment