Commit 9cd72cdd authored by Paul's avatar Paul
Browse files

Merge branch 'jit-layernorm-merge' into bert-opt-layernorm

parents 03c6967e 57a5c827
...@@ -75,20 +75,15 @@ constexpr auto is_vectorizable() ...@@ -75,20 +75,15 @@ constexpr auto is_vectorizable()
} }
template <class T> template <class T>
constexpr auto array2vec(T x) __device__ auto& array2vec(T& x)
{ {
using value_type = typename T::value_type; using value_type = typename T::value_type;
constexpr auto size = decltype(x.size()){}; constexpr auto size = decltype(x.size()){};
using type = vec<value_type, size>; using type = vec<value_type, size>;
static_assert(size != 3, "Wrong size"); if constexpr(is_const<T>{})
return __builtin_bit_cast(type, x); return reinterpret_cast<const type&>(x);
} else
return reinterpret_cast<type&>(x);
template <class T, class U, index_int N>
constexpr void vec2array(T& x, vec<U, N> v)
{
if constexpr(not is_const<T>{})
x = __builtin_bit_cast(T, v);
} }
template <class T, class... Ts> template <class T, class... Ts>
...@@ -101,11 +96,15 @@ constexpr auto array_for_each(T& x, Ts&... xs) ...@@ -101,11 +96,15 @@ constexpr auto array_for_each(T& x, Ts&... xs)
(is_vectorizable<typename Ts::value_type>() or ...)) and (is_vectorizable<typename Ts::value_type>() or ...)) and
size <= 8 and size > 1 and (size % 2 == 0)) size <= 8 and size > 1 and (size % 2 == 0))
{ {
[&](auto v, auto... vs) { if(__builtin_is_constant_evaluated())
f(v, vs...); {
vec2array(x, v); for(index_int i = 0; i < size; i++)
swallow{(vec2array(xs, vs), 0)...}; f(x[i], xs[i]...);
}(array2vec(x), array2vec(xs)...); }
else
{
f(array2vec(x), array2vec(xs)...);
}
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment