Commit f5ae909b authored by Jing Zhang

add bf16 builtins

parent 41fb383f
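The change below drops the extern declarations that bound wrapper functions to named LLVM intrinsics (`__asm("llvm.amdgcn.mfma...")`) and calls the clang `__builtin_amdgcn_mfma_*` builtins directly. A minimal sketch of the new call style follows; it is not part of the commit, the typedef and function names are illustrative stand-ins for the project's `ushort4_t`/`float16_t` types, and it assumes a compiler target that supports the bf16 `*_1k` MFMA instructions (e.g. gfx90a).

```cpp
// Sketch only: calling the bf16 MFMA compiler builtin directly instead of
// declaring the LLVM intrinsic by name. Names and typedefs are illustrative.
#include <hip/hip_runtime.h>

typedef unsigned short ushort4_t __attribute__((ext_vector_type(4))); // 4 packed bf16 values
typedef float float16_t __attribute__((ext_vector_type(16)));         // 16-wide f32 accumulator

// Old style (removed by this commit): an extern "C" declaration bound to
// __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k").
// New style: call the builtin directly; no declaration is needed.
__device__ float16_t mfma_bf16_32x32x8(ushort4_t a, ushort4_t b, float16_t c)
{
    // The trailing arguments are cbsz, abid and blgp; the commit passes 0 for all three.
    return __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a, b, c, 0, 0, 0);
}
```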
@@ -5,45 +5,6 @@
namespace ck {
// A, B, C, cbsz, abid, blgp
// bfp16
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
// int8
extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
// fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;
@@ -248,8 +209,7 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
}
};
@@ -263,8 +223,7 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
}
};
@@ -278,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
}
};
@@ -292,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
}
};
@@ -12,9 +12,9 @@
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
//#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
@@ -250,15 +250,15 @@ int main(int argc, char* argv[])
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
#endif
#if 1
#if 0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
#elif 0
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 0
#elif 1
using in_data_t = ushort;
using acc_data_t = float;
using out_data_t = ushort;