Commit 4cc472c0 authored by Paul's avatar Paul
Browse files

Formatting

parent 0c1ff20d
...@@ -330,7 +330,10 @@ inline auto outputs() ...@@ -330,7 +330,10 @@ inline auto outputs()
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; } MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; } MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); } MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins) { return not ins->get_shape().standard(); } MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return not ins->get_shape().standard();
}
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins) MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{ {
return ins->get_shape().broadcasted(); return ins->get_shape().broadcasted();
......
...@@ -203,7 +203,8 @@ struct hip_add_relu ...@@ -203,7 +203,8 @@ struct hip_add_relu
void move_broadcasted_back(std::vector<instruction_ref>& args) void move_broadcasted_back(std::vector<instruction_ref>& args)
{ {
// Ensure the last arguments is the broadcasted one // Ensure the last arguments is the broadcasted one
auto it = std::find_if(args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); }); auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end()) if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2)); std::swap(*it, *std::prev(args.end(), 2));
} }
...@@ -211,7 +212,8 @@ void move_broadcasted_back(std::vector<instruction_ref>& args) ...@@ -211,7 +212,8 @@ void move_broadcasted_back(std::vector<instruction_ref>& args)
void move_standard_front(std::vector<instruction_ref>& args) void move_standard_front(std::vector<instruction_ref>& args)
{ {
// Ensure the first arguments is the standard one // Ensure the first arguments is the standard one
auto it = std::find_if(args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); }); auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end()) if(it != args.end())
std::swap(*it, args.front()); std::swap(*it, args.front());
} }
...@@ -220,8 +222,11 @@ struct find_add_relu ...@@ -220,8 +222,11 @@ struct find_add_relu
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::relu")(match::arg(0)( return match::name("gpu::relu")(
match::any_of(match::name("gpu::add"), match::name("hip::triadd"), match::any_of[match::inputs()](match::standard_shape())).bind("add"))); match::arg(0)(match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of[match::inputs()](match::standard_shape()))
.bind("add")));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
...@@ -245,7 +250,8 @@ struct find_triadd ...@@ -245,7 +250,8 @@ struct find_triadd
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"), return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"),
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input"))); match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment