// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" namespace ck { template __device__ void inner_product(const TA& a, const TB& b, TC& c); template <> __device__ void inner_product(const float& a, const float& b, float& c) { #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) asm volatile("\n \ v_mac_f32 %0, %1, %2 \n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); #elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) asm volatile("\n \ v_fmac_f32 %0, %1, %2 \n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); #else c += a * b; #endif } template <> __device__ void inner_product(const float2_t& a, const float2_t& b, float& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); } template <> __device__ void inner_product(const float4_t& a, const float4_t& b, float& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); inner_product(vector_type{a}.AsType()[I2], vector_type{b}.AsType()[I2], c); inner_product(vector_type{a}.AsType()[I3], vector_type{b}.AsType()[I3], c); } template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { #if defined(CK_USE_AMD_V_DOT2_F32_F16) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot2_f32_f16 %0, %1, %2, %0\n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); #else c = __builtin_amdgcn_sdot2(a, b, c, false); #endif #else const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 2, 1>{}([&](auto i) { c += type_convert(a_vector.AsType()[i]) * type_convert(b_vector.AsType()[i]); }); #endif } template <> __device__ void inner_product(const half4_t& a, const half4_t& b, float& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); } template <> __device__ void inner_product(const half8_t& a, const half8_t& b, float& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); inner_product(vector_type{a}.AsType()[I2], vector_type{b}.AsType()[I2], c); inner_product(vector_type{a}.AsType()[I3], vector_type{b}.AsType()[I3], c); } template <> __device__ void inner_product(const int8_t& a, const int8_t& b, int32_t& c) { c += type_convert(a) * type_convert(b); } template <> __device__ void inner_product(const int8x2_t& a, const int8x2_t& b, int32_t& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); } template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { #if defined(CK_USE_AMD_V_DOT4_I32_I8) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ " : "=v"(c) : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); #else c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); #endif #else const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 4, 1>{}([&](auto i) { c += type_convert(a_vector.AsType()[i]) * type_convert(b_vector.AsType()[i]); }); #endif } template <> __device__ void inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); } template <> __device__ void inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; inner_product(vector_type{a}.AsType()[I0], vector_type{b}.AsType()[I0], c); inner_product(vector_type{a}.AsType()[I1], vector_type{b}.AsType()[I1], c); inner_product(vector_type{a}.AsType()[I2], vector_type{b}.AsType()[I2], c); inner_product(vector_type{a}.AsType()[I3], vector_type{b}.AsType()[I3], c); } } // namespace ck