Commit 6d84eb12 authored by Paul's avatar Paul
Browse files

Fix triadd bug

parent 98cfc027
......@@ -279,6 +279,15 @@ MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
MIGRAPH_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins;
if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
return ins;
return ctx.not_found();
}
inline auto name(std::string name)
{
return make_basic_pred_matcher(
......
......@@ -121,11 +121,11 @@ void trinary_broadcast_impl(
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data();
auto* yp = input2.data();
auto* zp = input2.data();
auto* zp = input3.data();
auto* outp = output.data();
const std::size_t nlocal = 1024;
......
......@@ -185,7 +185,7 @@ struct match_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(
return match::name("gpu::relu")(match::used_once(), match::arg(0)(
match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
}
......@@ -314,9 +314,9 @@ template <class... Ms>
auto conv_bias(Ms... ms)
{
return match::name("gpu::add")(
match::either_arg(0, 1)(bias_shape(match::output()).bind("bias"),
fusable_conv(match::output()).bind("conv")),
match::output(),
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")),
match::used_once(),
ms...);
}
......@@ -356,7 +356,7 @@ struct match_conv_bias
struct match_conv_bias_relu
{
context* ctx = nullptr;
auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
auto matcher() const { return match::name("gpu::relu")(match::used_once(), match::arg(0)(conv_bias())); }
void apply(program& p, match::matcher_result r) const
{
......
......@@ -174,6 +174,38 @@ struct test_add
}
};
struct test_triadd
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s);
auto sum = p.add_instruction(migraph::op::add{}, x, y);
p.add_instruction(migraph::op::add{}, sum, z);
return p;
}
};
struct test_triadd2
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {2, 3}};
migraph::shape b{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraph::op::broadcast{1, s}, z);
auto sum = p.add_instruction(migraph::op::add{}, x, y);
p.add_instruction(migraph::op::add{}, sum, zb);
return p;
}
};
struct test_add_broadcast
{
migraph::program create_program() const
......@@ -244,6 +276,22 @@ struct test_add_broadcast5
}
};
struct test_triadd_broadcast
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", {migraph::shape::float_type, {2, 2, 3}});
auto y = p.add_parameter("y", {migraph::shape::float_type, {2, 2}});
auto z = p.add_parameter("z", {migraph::shape::float_type, {2, 2, 3}});
auto by = p.add_instruction(migraph::op::broadcast{0, x->get_shape()}, y);
auto sum = p.add_instruction(migraph::op::add{}, x, by);
p.add_instruction(migraph::op::add{}, sum, z);
return p;
}
};
struct test_softmax
{
migraph::program create_program() const
......@@ -557,11 +605,14 @@ struct test_conv_bn_relu_pooling2
int main()
{
verify_program<test_add>();
verify_program<test_triadd>();
verify_program<test_triadd2>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>();
verify_program<test_softmax>();
verify_program<test_softmax2>();
verify_program<test_conv>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment