Commit 22f006cd authored by Paul
Browse files

Formatting

parent 8e68f6ca
...@@ -161,7 +161,7 @@ void nary_double_broadcast_vec_impl( ...@@ -161,7 +161,7 @@ void nary_double_broadcast_vec_impl(
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b1 = bp[bidx]; auto b1 = bp[bidx];
auto b2 = bp[bidx+bdim_vec_len]; auto b2 = bp[bidx + bdim_vec_len];
auto out = output.data()[i]; auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++) for(std::size_t j = 0; j < vec_size; j++)
{ {
...@@ -278,7 +278,8 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args) ...@@ -278,7 +278,8 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
} }
template <class... Arguments> template <class... Arguments>
bool broadcastable(bool& divisible_by_4, std::size_t max_size, argument result, argument barg, Arguments... args) bool broadcastable(
bool& divisible_by_4, std::size_t max_size, argument result, argument barg, Arguments... args)
{ {
divisible_by_4 = false; divisible_by_4 = false;
auto bshape = barg.get_shape(); auto bshape = barg.get_shape();
...@@ -351,19 +352,22 @@ auto nary(hipStream_t stream, argument result, Arguments... args) ...@@ -351,19 +352,22 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
auto barg1 = back_args(args...); auto barg1 = back_args(args...);
bool fallback1 = pop_back_args(args...)([&](auto&&... args2) { bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
auto barg2 = back_args(args2...); auto barg2 = back_args(args2...);
bool fallback2 = barg2.get_shape() == barg1.get_shape() and barg2.get_shape().broadcasted() and pop_back_args(args2...)([&](auto&&... args3) { bool fallback2 =
barg2.get_shape() == barg1.get_shape() and barg2.get_shape().broadcasted() and
pop_back_args(args2...)([&](auto&&... args3) {
bool divisible_by_4 = false; bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 1024, result, barg2, args3...)) if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
{ {
if(divisible_by_4) if(divisible_by_4)
nary_double_broadcast_vec_impl(stream, f, result, barg1, barg2, args3...); nary_double_broadcast_vec_impl(
stream, f, result, barg1, barg2, args3...);
else else
nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...); nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
return false; return false;
} }
return true; return true;
}); });
if (not fallback2) if(not fallback2)
return false; return false;
bool divisible_by_4 = false; bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 2048, result, barg1, args2...)) if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
......
...@@ -224,8 +224,8 @@ void move_broadcasted_back(std::vector<instruction_ref>& args) ...@@ -224,8 +224,8 @@ void move_broadcasted_back(std::vector<instruction_ref>& args)
{ {
// Ensure the last arguments is the broadcasted one // Ensure the last arguments is the broadcasted one
auto last = std::prev(args.end()); auto last = std::prev(args.end());
auto it = std::find_if( auto it =
args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); }); std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != last) if(it != last)
std::swap(*it, *std::prev(last)); std::swap(*it, *std::prev(last));
} }
...@@ -234,8 +234,8 @@ void move_standard_front(std::vector<instruction_ref>& args) ...@@ -234,8 +234,8 @@ void move_standard_front(std::vector<instruction_ref>& args)
{ {
// Ensure the first arguments is the standard one // Ensure the first arguments is the standard one
auto last = std::prev(args.end()); auto last = std::prev(args.end());
auto it = std::find_if( auto it =
args.begin(), last, [](auto arg) { return arg->get_shape().standard(); }); std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
if(it != last) if(it != last)
std::swap(*it, args.front()); std::swap(*it, args.front());
} }
...@@ -304,9 +304,8 @@ struct find_mul_add ...@@ -304,9 +304,8 @@ struct find_mul_add
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::add")(match::either_arg(0, 1)( return match::name("gpu::add")(
match::name("gpu::mul").bind("mul"), match::either_arg(0, 1)(match::name("gpu::mul").bind("mul"), match::any().bind("b")));
match::any().bind("b")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment