Commit 46abeca3 authored by Jing Zhang's avatar Jing Zhang
Browse files

add fp16 buildins

parent 01192e26
...@@ -6,22 +6,6 @@ ...@@ -6,22 +6,6 @@
namespace ck { namespace ck {
// A, B, C, cbsz, abid, blgp // A, B, C, cbsz, abid, blgp
// fp16
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
// bfp16 // bfp16
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k"); ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
...@@ -167,9 +151,9 @@ struct intrin_mfma_f32_32x32x4f16<64, 64> ...@@ -167,9 +151,9 @@ struct intrin_mfma_f32_32x32x4f16<64, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
} }
}; };
...@@ -180,7 +164,7 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> ...@@ -180,7 +164,7 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
} }
}; };
...@@ -194,7 +178,7 @@ struct intrin_mfma_f32_32x32x8f16<32, 32> ...@@ -194,7 +178,7 @@ struct intrin_mfma_f32_32x32x8f16<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -208,7 +192,7 @@ struct intrin_mfma_f32_16x16x16f16<16, 16> ...@@ -208,7 +192,7 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -222,7 +206,7 @@ struct intrin_mfma_f32_16x16x4f16<16, 64> ...@@ -222,7 +206,7 @@ struct intrin_mfma_f32_16x16x4f16<16, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
} }
}; };
...@@ -236,7 +220,7 @@ struct intrin_mfma_f32_4x4x4f16<4, 64> ...@@ -236,7 +220,7 @@ struct intrin_mfma_f32_4x4x4f16<4, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
} }
}; };
...@@ -247,9 +231,9 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> ...@@ -247,9 +231,9 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
} }
}; };
......
...@@ -3,7 +3,7 @@ rm -f CMakeCache.txt ...@@ -3,7 +3,7 @@ rm -f CMakeCache.txt
rm -f *.cmake rm -f *.cmake
rm -rf CMakeFiles rm -rf CMakeFiles
MY_PROJECT_SOURCE=../../.. MY_PROJECT_SOURCE=../
MY_PROJECT_INSTALL=../install.dir MY_PROJECT_INSTALL=../install.dir
cmake \ cmake \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment