Unverified commit 5dfafd00, authored by Paul Fultz II, committed by GitHub
Browse files

Fix vectorization of broadcasted inputs in pointwise fusions (#1011)

parent 2d4dcc47
...@@ -75,7 +75,7 @@ constexpr index_int find_vector_axis(Shapes... ss) ...@@ -75,7 +75,7 @@ constexpr index_int find_vector_axis(Shapes... ss)
index_int axis = 0; index_int axis = 0;
bool b = false; bool b = false;
by([&](auto s) { by([&](auto s) {
if(s.broadcasted() or b) if(b)
return; return;
auto it = find(s.strides.begin(), s.strides.end(), 1); auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end()) if(it == s.strides.end())
...@@ -89,14 +89,17 @@ constexpr index_int find_vector_axis(Shapes... ss) ...@@ -89,14 +89,17 @@ constexpr index_int find_vector_axis(Shapes... ss)
template <index_int N, class Axis, class... Shapes> template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis axis, Shapes... ss) constexpr auto is_vectorizable(Axis axis, Shapes... ss)
{ {
return (((ss.lens[axis] % N) == 0 and (ss.strides[axis] == 1 or ss.strides[axis] == 0)) and return (((ss.lens[axis] % N) == 0 and ss.strides[axis] == 1) and ...);
...);
} }
template <index_int N, class... Shapes> template <index_int N, class Shape>
constexpr bool is_vectorizable(Shapes... ss) constexpr bool is_vectorizable(Shape s)
{ {
return (is_vectorizable<N>(ss, find_vector_axis(ss)) and ...); auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
return false;
auto axis = it - s.strides.begin();
return (s.lens[axis] % N) == 0 and s.strides[axis] == 1;
} }
template <class P> template <class P>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment