Commit 6d84eb12 authored by Paul
Browse files

Fix triadd bug

parent 98cfc027
...@@ -279,6 +279,15 @@ MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins) ...@@ -279,6 +279,15 @@ MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
return ctx.not_found(); return ctx.not_found();
} }
// Matcher that accepts an instruction only when its result has a single
// consumer. Two situations qualify:
//   * the instruction has exactly one entry in outputs(), or
//   * it has no outputs at all and the instruction immediately following it
//     compares equal to ctx.not_found().
// Any other instruction yields ctx.not_found().
// NOTE(review): ctx.not_found() presumably denotes the end of the instruction
// sequence, which would make the second case "the final instruction" -- confirm.
MIGRAPH_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
{
    const auto& consumers = ins->outputs();
    const bool single_use = consumers.size() == 1;
    // Short-circuits so the successor is only inspected when there are no consumers.
    const bool unused_tail = consumers.empty() and std::next(ins) == ctx.not_found();
    if(not(single_use or unused_tail))
        return ctx.not_found();
    return ins;
}
inline auto name(std::string name) inline auto name(std::string name)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
......
...@@ -121,11 +121,11 @@ void trinary_broadcast_impl( ...@@ -121,11 +121,11 @@ void trinary_broadcast_impl(
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) { visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>; using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data(); auto* xp = input1.data();
auto* yp = input2.data(); auto* yp = input2.data();
auto* zp = input2.data(); auto* zp = input3.data();
auto* outp = output.data(); auto* outp = output.data();
const std::size_t nlocal = 1024; const std::size_t nlocal = 1024;
......
...@@ -185,7 +185,7 @@ struct match_add_relu ...@@ -185,7 +185,7 @@ struct match_add_relu
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::relu")(match::arg(0)( return match::name("gpu::relu")(match::used_once(), match::arg(0)(
match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add"))); match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
} }
...@@ -314,9 +314,9 @@ template <class... Ms> ...@@ -314,9 +314,9 @@ template <class... Ms>
auto conv_bias(Ms... ms) auto conv_bias(Ms... ms)
{ {
return match::name("gpu::add")( return match::name("gpu::add")(
match::either_arg(0, 1)(bias_shape(match::output()).bind("bias"), match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::output()).bind("conv")), fusable_conv(match::used_once()).bind("conv")),
match::output(), match::used_once(),
ms...); ms...);
} }
...@@ -356,7 +356,7 @@ struct match_conv_bias ...@@ -356,7 +356,7 @@ struct match_conv_bias
struct match_conv_bias_relu struct match_conv_bias_relu
{ {
context* ctx = nullptr; context* ctx = nullptr;
auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); } auto matcher() const { return match::name("gpu::relu")(match::used_once(), match::arg(0)(conv_bias())); }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
{ {
......
...@@ -174,6 +174,38 @@ struct test_add ...@@ -174,6 +174,38 @@ struct test_add
} }
}; };
// Verification case: two chained add instructions over three parameters that
// all share the same float shape {3}, i.e. (x + y) + z.
struct test_triadd
{
    // Builds and returns the program described above.
    migraph::program create_program() const
    {
        migraph::program prog;
        const migraph::shape vec_shape{migraph::shape::float_type, {3}};

        // Parameters are registered in the order x, y, z.
        auto first  = prog.add_parameter("x", vec_shape);
        auto second = prog.add_parameter("y", vec_shape);
        auto third  = prog.add_parameter("z", vec_shape);

        // The first add feeds the second add as its left operand.
        auto partial = prog.add_instruction(migraph::op::add{}, first, second);
        prog.add_instruction(migraph::op::add{}, partial, third);
        return prog;
    }
};
// Verification case: (x + y) + broadcast(z), where x and y are float {2, 3}
// parameters and z is a float {3} parameter that is passed through
// op::broadcast before the final add.
struct test_triadd2
{
    // Builds and returns the program described above.
    migraph::program create_program() const
    {
        migraph::program prog;
        const migraph::shape full_shape{migraph::shape::float_type, {2, 3}};
        const migraph::shape bias_shape{migraph::shape::float_type, {3}};

        // Parameters are registered in the order x, y, z.
        auto first  = prog.add_parameter("x", full_shape);
        auto second = prog.add_parameter("y", full_shape);
        auto third  = prog.add_parameter("z", bias_shape);

        // broadcast is constructed from the value 1 and the {2, 3} shape.
        // NOTE(review): the 1 is presumably the axis z is aligned to -- confirm.
        auto third_expanded =
            prog.add_instruction(migraph::op::broadcast{1, full_shape}, third);

        auto partial = prog.add_instruction(migraph::op::add{}, first, second);
        prog.add_instruction(migraph::op::add{}, partial, third_expanded);
        return prog;
    }
};
struct test_add_broadcast struct test_add_broadcast
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -244,6 +276,22 @@ struct test_add_broadcast5 ...@@ -244,6 +276,22 @@ struct test_add_broadcast5
} }
}; };
// Verification case: (x + broadcast(y)) + z, where x and z are float
// {2, 2, 3} parameters and y is a float {2, 2} parameter that is passed
// through op::broadcast (constructed from 0 and x's shape) before the first add.
struct test_triadd_broadcast
{
    // Builds and returns the program described above.
    migraph::program create_program() const
    {
        migraph::program p;
        // Parameters are registered in the order x, y, z.
        auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}});
        auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}});
        auto z = p.add_parameter("z", {migraph::shape::float_type, {2, 2, 3}});
        // NOTE(review): the 0 is presumably the axis y is aligned to -- confirm.
        auto by  = p.add_instruction(migraph::op::broadcast{0, x->get_shape()}, y);
        auto sum = p.add_instruction(migraph::op::add{}, x, by);
        p.add_instruction(migraph::op::add{}, sum, z);
        return p;
    }
};
struct test_softmax struct test_softmax
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -557,11 +605,14 @@ struct test_conv_bn_relu_pooling2 ...@@ -557,11 +605,14 @@ struct test_conv_bn_relu_pooling2
int main() int main()
{ {
verify_program<test_add>(); verify_program<test_add>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>(); verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>(); verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>(); verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>(); verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>(); verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_softmax>(); verify_program<test_softmax>();
verify_program<test_softmax2>(); verify_program<test_softmax2>();
verify_program<test_conv>(); verify_program<test_conv>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment