Commit d673e0c4 authored by Paul
Browse files

Fix fusion for triadd

parent 270194c4
......@@ -136,6 +136,36 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
op.dilation == make_array<size_t>(1, 1);
}
/// GPU operator named "hip::triadd".
///
/// It takes four arguments: the first three are passed to device::add as
/// inputs and the fourth is passed as the leading argument of that call and
/// then returned.
struct hip_triadd
{
    /// Name under which this operator is identified.
    std::string name() const { return "hip::triadd"; }

    /// Runs the check_shapes `has(4)` validation on the inputs, then reports
    /// the first input shape as the result shape.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }

    /// Calls device::add with the fourth argument first, followed by the
    /// first three arguments, and returns the fourth argument.
    /// NOTE(review): the fourth argument is presumably the output buffer --
    /// confirm against the device::add declaration.
    argument compute(context&, const shape&, const std::vector<argument>& args) const
    {
        const argument& output = args.at(3);
        device::add(output, args.at(0), args.at(1), args.at(2));
        return output;
    }
};
/// GPU operator named "hip::triadd_relu".
///
/// It takes four arguments: the first three are passed to device::add_relu as
/// inputs and the fourth is passed as the leading argument of that call and
/// then returned.
struct hip_triadd_relu
{
    /// Name under which this operator is identified.
    std::string name() const { return "hip::triadd_relu"; }

    /// Runs the check_shapes `has(4)` validation on the inputs, then reports
    /// the first input shape as the result shape.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }

    /// Calls device::add_relu with the fourth argument first, followed by the
    /// first three arguments, and returns the fourth argument.
    /// NOTE(review): the fourth argument is presumably the output buffer --
    /// confirm against the device::add_relu declaration.
    argument compute(context&, const shape&, const std::vector<argument>& args) const
    {
        const argument& output = args.at(3);
        device::add_relu(output, args.at(0), args.at(1), args.at(2));
        return output;
    }
};
struct hip_add_relu
{
std::string name() const { return "hip::add_relu"; }
......@@ -155,7 +185,7 @@ struct match_add_relu
{
// Builds the matcher for a "gpu::relu" instruction whose argument 0 is either
// a "gpu::add" or a "hip::triadd". The matched producer is bound under the
// name "add" so that apply() can look it up and dispatch on its name.
auto matcher() const
{
    return match::name("gpu::relu")(
        match::arg(0)(
            match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
}
void apply(program& p, match::matcher_result r) const
......@@ -165,7 +195,36 @@ struct match_add_relu
auto args = add_ins->inputs();
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_relu{}, args);
if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, hip_add_relu{}, args);
else if(add_ins->name() == "hip::triadd")
p.replace_instruction(ins, hip_triadd_relu{}, args);
}
};
/// Rewrites a "gpu::add" that consumes another "gpu::add" into a single
/// hip_triadd instruction.
struct match_triadd
{
    /// Matches a "gpu::add" where one of arguments 0 and 1 is itself a
    /// "gpu::add" (bound as "add") and the other is anything (bound as
    /// "input").
    auto matcher() const
    {
        auto inner_add   = match::name("gpu::add").bind("add");
        auto other_input = match::any().bind("input");
        return match::name("gpu::add")(match::either_arg(0, 1)(inner_add, other_input));
    }

    /// Replaces the matched outer add with hip_triadd, unless more than one
    /// argument of the inner add has a broadcasted shape.
    void apply(program& p, match::matcher_result r) const
    {
        auto outer_add = r.result;
        auto inner_add = r.instructions["add"];
        auto extra_in  = r.instructions["input"];

        // Start from the inner add's argument list.
        auto fused_args = inner_add->inputs();

        auto has_broadcast_shape = [](auto operand) {
            return operand->get_shape().broadcasted();
        };

        // Give up when the inner add carries two or more broadcasted arguments.
        const auto broadcast_count =
            std::count_if(fused_args.begin(), fused_args.end(), has_broadcast_shape);
        if(broadcast_count > 1)
        {
            return;
        }

        // The non-add operand of the outer add becomes the first argument.
        fused_args.insert(fused_args.begin(), extra_in);

        // Move the first broadcasted argument (if any) into the
        // second-to-last slot, i.e. the slot just before the final argument.
        auto broadcast_pos =
            std::find_if(fused_args.begin(), fused_args.end(), has_broadcast_shape);
        if(broadcast_pos != fused_args.end())
        {
            std::swap(*broadcast_pos, *std::prev(fused_args.end(), 2));
        }

        // The final argument is taken from the outer add's last input.
        fused_args.back() = outer_add->inputs().back();
        p.replace_instruction(outer_add, hip_triadd{}, fused_args);
    }
};
......@@ -305,8 +364,14 @@ struct match_conv_bias_relu
// Runs the GPU fusion matchers over the program with a single
// match::find_matches call. The matchers are passed in this order:
// match_triadd, match_conv_bias_relu, match_conv_bias, match_add_relu.
void fuse_ops::apply(program& p) const
{
    // clang-format off
    match::find_matches(p,
        match_triadd{},
        match_conv_bias_relu{ctx},
        match_conv_bias{ctx},
        match_add_relu{}
    );
    // clang-format on
}
} // namespace gpu
......
......@@ -10,6 +10,8 @@ namespace device {
// Two-input overload: takes `result` followed by two input arguments.
// NOTE(review): presumably writes relu(arg1 + arg2) into `result` -- confirm
// against the device implementation.
void add_relu(const argument& result, const argument& arg1, const argument& arg2);
// Three-input overload: takes `result` followed by three input arguments.
// NOTE(review): presumably writes relu(arg1 + arg2 + arg3) into `result` --
// confirm against the device implementation.
void add_relu(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace migraph
......
......@@ -35,10 +35,10 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
fuse_ops{&ctx},
dead_code_elimination{},
eliminate_contiguous{},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
memory_coloring{"hip::allocate"},
eliminate_workspace{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment