"composable_kernel/include/utility/config.hpp" did not exist on "5c7cec11159d22636dd4c1119e7e430d156a8df7"
Commit 0c1ff20d authored by Paul's avatar Paul
Browse files

Fix error with nonstandard shapes

parent 5f31ae3f
......@@ -330,6 +330,7 @@ inline auto outputs()
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins) { return not ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
......
......@@ -200,12 +200,28 @@ struct hip_add_relu
}
};
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
// Ensure the last arguments is the broadcasted one
auto it = std::find_if(args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
}
void move_standard_front(std::vector<instruction_ref>& args)
{
// Ensure the first arguments is the standard one
auto it = std::find_if(args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end())
std::swap(*it, args.front());
}
struct find_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(
match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
match::any_of(match::name("gpu::add"), match::name("hip::triadd"), match::any_of[match::inputs()](match::standard_shape())).bind("add")));
}
void apply(program& p, match::matcher_result r) const
......@@ -213,6 +229,9 @@ struct find_add_relu
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto args = add_ins->inputs();
move_standard_front(args);
move_broadcasted_back(args);
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
if(add_ins->name() == "gpu::add")
......@@ -227,7 +246,7 @@ struct find_triadd
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
match::any().bind("input")));
match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
}
void apply(program& p, match::matcher_result r) const
......@@ -242,10 +261,9 @@ struct find_triadd
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
return;
args.insert(args.begin(), input_ins);
// Ensure the last arguments is the broadcasted one
auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
move_standard_front(args);
move_broadcasted_back(args);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_triadd{}, args);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment