Unverified Commit 5dfafd00 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix vectorization of broadcasted inputs in pointwise fusions (#1011)

parent 2d4dcc47
......@@ -75,7 +75,7 @@ constexpr index_int find_vector_axis(Shapes... ss)
index_int axis = 0;
bool b = false;
by([&](auto s) {
if(s.broadcasted() or b)
if(b)
return;
auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
......@@ -89,14 +89,17 @@ constexpr index_int find_vector_axis(Shapes... ss)
template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis axis, Shapes... ss)
{
return (((ss.lens[axis] % N) == 0 and (ss.strides[axis] == 1 or ss.strides[axis] == 0)) and
...);
return (((ss.lens[axis] % N) == 0 and ss.strides[axis] == 1) and ...);
}
template <index_int N, class... Shapes>
constexpr bool is_vectorizable(Shapes... ss)
template <index_int N, class Shape>
constexpr bool is_vectorizable(Shape s)
{
return (is_vectorizable<N>(ss, find_vector_axis(ss)) and ...);
auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
return false;
auto axis = it - s.strides.begin();
return (s.lens[axis] % N) == 0 and s.strides[axis] == 1;
}
template <class P>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment