Commit 22f006cd authored by Paul
Browse files

Formatting

parent 8e68f6ca
......@@ -161,7 +161,7 @@ void nary_double_broadcast_vec_impl(
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b1 = bp[bidx];
auto b2 = bp[bidx+bdim_vec_len];
auto b2 = bp[bidx + bdim_vec_len];
auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++)
{
......@@ -278,7 +278,8 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}
template <class... Arguments>
bool broadcastable(bool& divisible_by_4, std::size_t max_size, argument result, argument barg, Arguments... args)
bool broadcastable(
bool& divisible_by_4, std::size_t max_size, argument result, argument barg, Arguments... args)
{
divisible_by_4 = false;
auto bshape = barg.get_shape();
......@@ -351,19 +352,22 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
auto barg1 = back_args(args...);
bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
auto barg2 = back_args(args2...);
bool fallback2 = barg2.get_shape() == barg1.get_shape() and barg2.get_shape().broadcasted() and pop_back_args(args2...)([&](auto&&... args3) {
bool fallback2 =
barg2.get_shape() == barg1.get_shape() and barg2.get_shape().broadcasted() and
pop_back_args(args2...)([&](auto&&... args3) {
bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
{
if(divisible_by_4)
nary_double_broadcast_vec_impl(stream, f, result, barg1, barg2, args3...);
nary_double_broadcast_vec_impl(
stream, f, result, barg1, barg2, args3...);
else
nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
return false;
}
return true;
});
if (not fallback2)
if(not fallback2)
return false;
bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
......
......@@ -224,8 +224,8 @@ void move_broadcasted_back(std::vector<instruction_ref>& args)
{
// Ensure the last argument is the broadcasted one
auto last = std::prev(args.end());
auto it = std::find_if(
args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
auto it =
std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != last)
std::swap(*it, *std::prev(last));
}
......@@ -234,8 +234,8 @@ void move_standard_front(std::vector<instruction_ref>& args)
{
// Ensure the first argument is the standard one
auto last = std::prev(args.end());
auto it = std::find_if(
args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
auto it =
std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
if(it != last)
std::swap(*it, args.front());
}
......@@ -304,9 +304,8 @@ struct find_mul_add
{
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::mul").bind("mul"),
match::any().bind("b")));
return match::name("gpu::add")(
match::either_arg(0, 1)(match::name("gpu::mul").bind("mul"), match::any().bind("b")));
}
void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment