#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP

#include "float_type.hpp"

namespace ck {

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
#if CK_USE_AMD_V_FMAC_F32
    asm volatile("\n \
            v_fmac_f32 %0, %2, %3 \n \
            v_fmac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
    asm volatile("\n \
            v_mac_f32 %0, %2, %3 \n \
            v_mac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#endif
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
#if CK_USE_AMD_V_FMAC_F32
    asm volatile("\n \
            v_fmac_f32 %0, %4, %5 \n \
            v_fmac_f32 %1, %4, %6 \n \
            v_fmac_f32 %2, %4, %7 \n \
            v_fmac_f32 %3, %4, %8 \n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
    asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#endif
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %3, %0\n \
            v_dot2_f32_f16 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
    // view each half4_t as two half2_t slices, which is what v_dot2 consumes
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);

    // run dot2 twice: once on the low half2 slice, once on the high slice
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %4, %0\n \
            v_dot2_f32_f16 %1, %2, %6, %1\n \
            v_dot2_f32_f16 %0, %3, %5, %0\n \
            v_dot2_f32_f16 %1, %3, %7, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]),
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]),
                   "0"(c0),
                   "1"(c1));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
                                               half2_t b0,
                                               half2_t b1,
                                               half2_t b2,
                                               half2_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %5, %0\n \
            v_dot2_f32_f16 %1, %4, %6, %1\n \
            v_dot2_f32_f16 %2, %4, %7, %2\n \
            v_dot2_f32_f16 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               half4_t b0,
                                               half4_t b1,
                                               half4_t b2,
                                               half4_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{
    // view each half4_t as two half2_t slices, which is what v_dot2 consumes
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);

    // run dot2 twice: once on the low half2 slice, once on the high slice
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %6, %0\n \
            v_dot2_f32_f16 %1, %4, %8, %1\n \
            v_dot2_f32_f16 %2, %4, %10, %2\n \
            v_dot2_f32_f16 %3, %4, %12, %3\n \
            v_dot2_f32_f16 %0, %5, %7, %0\n \
            v_dot2_f32_f16 %1, %5, %9, %1\n \
            v_dot2_f32_f16 %2, %5, %11, %2\n \
            v_dot2_f32_f16 %3, %5, %13, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(p_a_half2[0]),
                   "v"(p_a_half2[1]),
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
                   "v"(p_b1_half2[1]),
                   "v"(p_b2_half2[0]),
                   "v"(p_b2_half2[1]),
                   "v"(p_b3_half2[0]),
                   "v"(p_b3_half2[1]),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
// inline asm path; the __builtin_amdgcn_sdot4 path below is equivalent
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %2, %3, %0\n \
            v_dot4_i32_i8 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
    c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
    c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
#endif
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                                               int8x4_t b0,
                                               int8x4_t b1,
                                               int8x4_t b2,
                                               int8x4_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
// inline asm path; the __builtin_amdgcn_sdot4 path below is equivalent
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %4, %5, %0\n \
            v_dot4_i32_i8 %1, %4, %6, %1\n \
            v_dot4_i32_i8 %2, %4, %7, %2\n \
            v_dot4_i32_i8 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(a, b0, c0, false);
    c1 = __builtin_amdgcn_sdot4(a, b1, c1, false);
    c2 = __builtin_amdgcn_sdot4(a, b2, c2, false);
    c3 = __builtin_amdgcn_sdot4(a, b3, c3, false);
#endif
}

} // namespace ck
#endif
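
// Usage sketch (illustrative, not part of the original header): a minimal
// device-side loop showing how the fp32 amd_assembly_outer_product_1x4
// overload accumulates one 1x4 strip of a register-resident C tile in a GEMM
// inner loop. The kernel name, the K extent, the B layout (four columns per
// k step), and the CK_AMD_INLINE_ASM_EXAMPLE guard macro are all assumptions
// made for this sketch; it compiles only when the guard is defined.
#ifdef CK_AMD_INLINE_ASM_EXAMPLE
__global__ void outer_product_1x4_example_kernel(const float* p_a,
                                                 const float* p_b,
                                                 float* p_c)
{
    // one 1x4 strip of the C tile, kept in registers across the K loop
    float c0 = 0.0f, c1 = 0.0f, c2 = 0.0f, c3 = 0.0f;

    // hypothetical K extent: each step reads one A value and four B values
    constexpr int K = 8;

    for(int k = 0; k < K; ++k)
    {
        ck::amd_assembly_outer_product_1x4(p_a[k],
                                           p_b[4 * k + 0],
                                           p_b[4 * k + 1],
                                           p_b[4 * k + 2],
                                           p_b[4 * k + 3],
                                           c0,
                                           c1,
                                           c2,
                                           c3);
    }

    // write the accumulated strip back out
    p_c[0] = c0;
    p_c[1] = c1;
    p_c[2] = c2;
    p_c[3] = c3;
}
#endif // CK_AMD_INLINE_ASM_EXAMPLE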