Commit 26d6f15f authored by Paul's avatar Paul
Browse files

Dont run kernel twice for broadcast

parent 524baf05
...@@ -185,8 +185,8 @@ auto pop_back_args(Ts&&... xs) ...@@ -185,8 +185,8 @@ auto pop_back_args(Ts&&... xs)
return [&](auto f) { return [&](auto f) {
using tuple_type = std::tuple<Ts&&...>; using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...); auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>( return sequence_c<sizeof...(Ts) - 1>(
[&](auto... is) { f(std::get<is>(static_cast<tuple_type&&>(t))...); }); [&](auto... is) { return f(std::get<is>(static_cast<tuple_type&&>(t))...); });
}; };
} }
......
...@@ -399,7 +399,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args) ...@@ -399,7 +399,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
return [=](auto f) { return [=](auto f) {
auto barg = back_args(args...); auto barg = back_args(args...);
pop_back_args(args...)([&](auto&&... args2) { bool fallback = pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape(); auto bshape = barg.get_shape();
const bool standard = const bool standard =
all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); }); all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
...@@ -425,11 +425,13 @@ auto nary(hipStream_t stream, argument result, Arguments... args) ...@@ -425,11 +425,13 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
nary_broadcast_vec_impl(stream, f, result, barg, args2...); nary_broadcast_vec_impl(stream, f, result, barg, args2...);
else else
nary_broadcast_impl(stream, f, result, barg, args2...); nary_broadcast_impl(stream, f, result, barg, args2...);
return; return false;
} }
} }
return true;
}); });
nary_impl(stream, f, result, args...); if (fallback)
nary_impl(stream, f, result, args...);
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment