"docs/source/en/vscode:/vscode.git/clone" did not exist on "98730c5dd7d572cb9b7435afe0215247663362ba"
Commit e4026cb5 authored by Max Podkorytov
Browse files

add bhalf vector instances to inner_product.hpp

parent 1fefd82e
...@@ -84,6 +84,18 @@ __device__ void inner_product<half_t, half_t, float>(const half_t& a, const half ...@@ -84,6 +84,18 @@ __device__ void inner_product<half_t, half_t, float>(const half_t& a, const half
inner_product(type_convert<float>(a), type_convert<float>(b), c); inner_product(type_convert<float>(a), type_convert<float>(b), c);
} }
// Two-lane bhalf_t dot product with float accumulation.
//
// Each lane of `a` and `b` is converted to float, the converted lanes are
// multiplied pairwise, and every product is added into `c`. `c` is updated in
// place; its incoming value is kept as the starting sum.
template <>
__device__ void
inner_product<bhalf2_t, bhalf2_t, float>(const bhalf2_t& a, const bhalf2_t& b, float& c)
{
    // Wrap the raw vectors so that single lanes can be addressed.
    const vector_type<bhalf_t, 2> lhs{a};
    const vector_type<bhalf_t, 2> rhs{b};

    const auto& lhs_lanes = lhs.AsType<bhalf_t>();
    const auto& rhs_lanes = rhs.AsType<bhalf_t>();

    // Compile-time loop over lane indices 0 and 1. The multiply and the
    // accumulate stay in one expression, exactly as a single `c += x * y`.
    static_for<0, 2, 1>{}([&](auto lane) {
        c += type_convert<float>(lhs_lanes[lane]) * type_convert<float>(rhs_lanes[lane]);
    });
}
template <> template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c) __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{ {
...@@ -112,6 +124,19 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h ...@@ -112,6 +124,19 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
#endif #endif
} }
// Four-lane bhalf_t dot product with float accumulation.
//
// The inputs are viewed as two bhalf2_t halves each, and every pair of halves
// is forwarded to the two-lane inner_product overload, which adds its partial
// result into `c`. `c` is updated in place; its incoming value is kept as the
// starting sum.
template <>
__device__ void
inner_product<bhalf4_t, bhalf4_t, float>(const bhalf4_t& a, const bhalf4_t& b, float& c)
{
    const vector_type<bhalf_t, 4> a_vector{a};
    const vector_type<bhalf_t, 4> b_vector{b};

    // The sub-vectors must be viewed as bhalf2_t, not half2_t: the storage
    // holds bhalf_t lanes, and a half2_t view would route them to the
    // half-precision overload instead of the bhalf one.
    static_for<0, 2, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<bhalf2_t>()[i], b_vector.AsType<bhalf2_t>()[i], c);
    });
}
template <> template <>
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c) __device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{ {
...@@ -127,6 +152,19 @@ __device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const h ...@@ -127,6 +152,19 @@ __device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const h
c); c);
} }
// Eight-lane bhalf_t dot product with float accumulation.
//
// The inputs are viewed as four bhalf2_t sub-vectors each, and every pair of
// sub-vectors is forwarded to the two-lane inner_product overload, which adds
// its partial result into `c`. `c` is updated in place; its incoming value is
// kept as the starting sum.
template <>
__device__ void
inner_product<bhalf8_t, bhalf8_t, float>(const bhalf8_t& a, const bhalf8_t& b, float& c)
{
    const vector_type<bhalf_t, 8> a_vector{a};
    const vector_type<bhalf_t, 8> b_vector{b};

    // The sub-vectors must be viewed as bhalf2_t, not half2_t: the storage
    // holds bhalf_t lanes, and a half2_t view would route them to the
    // half-precision overload instead of the bhalf one.
    static_for<0, 4, 1>{}([&](auto i) {
        inner_product(a_vector.AsType<bhalf2_t>()[i], b_vector.AsType<bhalf2_t>()[i], c);
    });
}
template <> template <>
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c) __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment