Unverified Commit 6d92959a authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Replace llvm Intrinsics with clang buildins (#65)

* test mfma builtins

* add fp16 buildins

* add int8 buildins

* add bfl16 buildins

* simplify host conv forward

* clean

* clean
parent 4be7f019
...@@ -5,77 +5,6 @@ ...@@ -5,77 +5,6 @@
namespace ck { namespace ck {
// A, B, C, cbsz, abid, blgp
// fp32
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
// fp16
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
// bfp16
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
// int8
extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
// fp32 // fp32
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32; struct intrin_mfma_f32_32x32x1f32;
...@@ -86,9 +15,9 @@ struct intrin_mfma_f32_32x32x1f32<64, 64> ...@@ -86,9 +15,9 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
} }
}; };
...@@ -99,7 +28,7 @@ struct intrin_mfma_f32_32x32x1f32<32, 64> ...@@ -99,7 +28,7 @@ struct intrin_mfma_f32_32x32x1f32<32, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
} }
}; };
...@@ -113,7 +42,7 @@ struct intrin_mfma_f32_32x32x2f32<32, 32> ...@@ -113,7 +42,7 @@ struct intrin_mfma_f32_32x32x2f32<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -127,7 +56,7 @@ struct intrin_mfma_f32_16x16x4f32<16, 16> ...@@ -127,7 +56,7 @@ struct intrin_mfma_f32_16x16x4f32<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -141,8 +70,7 @@ struct intrin_mfma_f32_16x16x1f32<16, 64> ...@@ -141,8 +70,7 @@ struct intrin_mfma_f32_16x16x1f32<16, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
} }
}; };
...@@ -156,7 +84,7 @@ struct intrin_mfma_f32_4x4x1f32<4, 64> ...@@ -156,7 +84,7 @@ struct intrin_mfma_f32_4x4x1f32<4, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
} }
}; };
...@@ -167,9 +95,9 @@ struct intrin_mfma_f32_4x4x1f32<8, 64> ...@@ -167,9 +95,9 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
} }
}; };
...@@ -184,9 +112,9 @@ struct intrin_mfma_f32_32x32x4f16<64, 64> ...@@ -184,9 +112,9 @@ struct intrin_mfma_f32_32x32x4f16<64, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
} }
}; };
...@@ -197,7 +125,7 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> ...@@ -197,7 +125,7 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0); reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
} }
}; };
...@@ -211,7 +139,7 @@ struct intrin_mfma_f32_32x32x8f16<32, 32> ...@@ -211,7 +139,7 @@ struct intrin_mfma_f32_32x32x8f16<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -225,7 +153,7 @@ struct intrin_mfma_f32_16x16x16f16<16, 16> ...@@ -225,7 +153,7 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -239,7 +167,7 @@ struct intrin_mfma_f32_16x16x4f16<16, 64> ...@@ -239,7 +167,7 @@ struct intrin_mfma_f32_16x16x4f16<16, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
} }
}; };
...@@ -253,7 +181,7 @@ struct intrin_mfma_f32_4x4x4f16<4, 64> ...@@ -253,7 +181,7 @@ struct intrin_mfma_f32_4x4x4f16<4, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
} }
}; };
...@@ -264,9 +192,9 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> ...@@ -264,9 +192,9 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
} }
}; };
...@@ -281,9 +209,8 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> ...@@ -281,9 +209,8 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k( reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -296,9 +223,8 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> ...@@ -296,9 +223,8 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k( reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -311,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32> ...@@ -311,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -325,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16> ...@@ -325,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c) __device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
...@@ -340,12 +266,12 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> ...@@ -340,12 +266,12 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x16_t>()(Number<0>{}) = reg_c.template AsType<int32x16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a), __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
bit_cast<int>(reg_b), bit_cast<int>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}], reg_c.template AsType<int32x16_t>()[Number<0>{}],
0, 0,
0, 0,
0); 0);
} }
}; };
...@@ -359,12 +285,12 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> ...@@ -359,12 +285,12 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a), __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
bit_cast<int>(reg_b), bit_cast<int>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
0, 0,
0, 0,
0); 0);
} }
}; };
......
...@@ -169,6 +169,8 @@ struct DynamicBuffer ...@@ -169,6 +169,8 @@ struct DynamicBuffer
is_same<remove_cvref_t<X>, int8x2_t>::value) || is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value && (is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value) || is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(is_same<remove_cvref_t<T>, int8x4_t>::value && (is_same<remove_cvref_t<T>, int8x4_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value) || is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(is_same<remove_cvref_t<T>, int8x8_t>::value && (is_same<remove_cvref_t<T>, int8x8_t>::value &&
...@@ -202,6 +204,14 @@ struct DynamicBuffer ...@@ -202,6 +204,14 @@ struct DynamicBuffer
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cvref_t<X>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
is_same<remove_cvref_t<X>, int8x4_t>::value) is_same<remove_cvref_t<X>, int8x4_t>::value)
{ {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -70,6 +71,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -70,6 +71,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB, FloatAB,
...@@ -79,6 +82,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -79,6 +82,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
AGridDesc_K0_M_K1, AGridDesc_K0_M_K1,
BGridDesc_K0_N_K, BGridDesc_K0_N_K,
CMNGridDesc, CMNGridDesc,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -87,7 +93,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -87,7 +93,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -95,7 +100,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -95,7 +100,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -103,17 +108,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -103,17 +108,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector>;
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
CAccessOrderMRepeatNRepeat,
ABlockLdsAddExtraM,
BBlockLdsAddExtraN>;
{ {
std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", " std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
...@@ -152,6 +150,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -152,6 +150,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float ave_time = 0; float ave_time = 0;
auto element_op_ = ElementwiseOperation{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
...@@ -162,6 +162,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -162,6 +162,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>, remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>, remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
true>; true>;
...@@ -176,6 +179,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -176,6 +179,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map); block_2_ctile_map);
} }
else else
...@@ -187,6 +193,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -187,6 +193,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t<AGridDesc_K0_M_K1>, remove_reference_t<AGridDesc_K0_M_K1>,
remove_reference_t<BGridDesc_K0_N_K>, remove_reference_t<BGridDesc_K0_N_K>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
ElementwiseOperation,
ElementwiseOperation,
ElementwiseOperation,
remove_reference_t<Block2CTileMap>, remove_reference_t<Block2CTileMap>,
false>; false>;
...@@ -201,6 +210,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -201,6 +210,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
element_op_,
element_op_,
element_op_,
block_2_ctile_map); block_2_ctile_map);
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment