Commit 26d6f15f authored by Paul's avatar Paul
Browse files

Dont run kernel twice for broadcast

parent 524baf05
......@@ -185,8 +185,8 @@ auto pop_back_args(Ts&&... xs)
return [&](auto f) {
using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...);
sequence_c<sizeof...(Ts) - 1>(
[&](auto... is) { f(std::get<is>(static_cast<tuple_type&&>(t))...); });
return sequence_c<sizeof...(Ts) - 1>(
[&](auto... is) { return f(std::get<is>(static_cast<tuple_type&&>(t))...); });
};
}
......
......@@ -399,7 +399,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
return [=](auto f) {
auto barg = back_args(args...);
pop_back_args(args...)([&](auto&&... args2) {
bool fallback = pop_back_args(args...)([&](auto&&... args2) {
auto bshape = barg.get_shape();
const bool standard =
all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
......@@ -425,11 +425,13 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
nary_broadcast_vec_impl(stream, f, result, barg, args2...);
else
nary_broadcast_impl(stream, f, result, barg, args2...);
return;
return false;
}
}
return true;
});
nary_impl(stream, f, result, args...);
if (fallback)
nary_impl(stream, f, result, args...);
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment