Commit 8a9daf16 authored by Bartlomiej's avatar Bartlomiej Committed by Bartlomiej Kocot
Browse files

Fix builtin for inner_produxt fp16

parent df4cc03f
...@@ -75,9 +75,8 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f ...@@ -75,9 +75,8 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
// builtin is disabled because it does not generate s_nop #if defined(CK_USE_AMD_V_DOT2_F32_F16)
// and this can lead to hazards #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#if defined(CK_USE_AMD_V_DOT2_F32_F16) && CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
// Use 3 x s_nop to avoid hazard (mi200 cdna2 isa) // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa)
asm volatile("\n \ asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \ v_dot2_f32_f16 %0, %1, %2, %0\n \
...@@ -85,7 +84,9 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h ...@@ -85,7 +84,9 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
" "
: "=v"(c) : "=v"(c)
: "v"(a), "v"(b), "0"(c)); : "v"(a), "v"(b), "0"(c));
c = __builtin_amdgcn_sdot2(a, b, c, false); #else
c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif
#else #else
const vector_type<half_t, 2> a_vector{a}; const vector_type<half_t, 2> a_vector{a};
const vector_type<half_t, 2> b_vector{b}; const vector_type<half_t, 2> b_vector{b};
...@@ -163,9 +164,8 @@ template <> ...@@ -163,9 +164,8 @@ template <>
__device__ void __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c) inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{ {
// builtin is disabled because it does not generate s_nop #if defined(CK_USE_AMD_V_DOT4_I32_I8)
// and this can lead to hazards #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#if defined(CK_USE_AMD_V_DOT4_I32_I8) && CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
// Use 3 x s_nop to avoid hazard (mi200 cdna2 isa) // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa)
asm volatile("\n \ asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \
...@@ -173,6 +173,9 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, ...@@ -173,6 +173,9 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
" "
: "=v"(c) : "=v"(c)
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c)); : "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b), c, false);
#endif
#else #else
const vector_type<int8_t, 4> a_vector{a}; const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b}; const vector_type<int8_t, 4> b_vector{b};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment