Commit 4cc472c0 authored by Paul's avatar Paul
Browse files

Formatting

parent 0c1ff20d
......@@ -330,7 +330,10 @@ inline auto outputs()
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins) { return not ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return not ins->get_shape().standard();
}
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
......
......@@ -203,7 +203,8 @@ struct hip_add_relu
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
// Ensure the last arguments is the broadcasted one
auto it = std::find_if(args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
}
......@@ -211,7 +212,8 @@ void move_broadcasted_back(std::vector<instruction_ref>& args)
void move_standard_front(std::vector<instruction_ref>& args)
{
// Ensure the first arguments is the standard one
auto it = std::find_if(args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end())
std::swap(*it, args.front());
}
......@@ -220,8 +222,11 @@ struct find_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(
match::any_of(match::name("gpu::add"), match::name("hip::triadd"), match::any_of[match::inputs()](match::standard_shape())).bind("add")));
return match::name("gpu::relu")(
match::arg(0)(match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of[match::inputs()](match::standard_shape()))
.bind("add")));
}
void apply(program& p, match::matcher_result r) const
......@@ -245,7 +250,8 @@ struct find_triadd
{
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"),
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment