Commit 9ccf5e5c authored by turneram's avatar turneram
Browse files

Use common.hpp to add broadcasts

parent b07f5e4e
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp> #include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -42,15 +43,11 @@ struct find_gelu_erf ...@@ -42,15 +43,11 @@ struct find_gelu_erf
return; return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}}); auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = m.insert_instruction( auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lit);
mul = m.insert_instruction(ins, make_op("mul"), x, mul);
auto sig = m.insert_instruction(ins, make_op("neg"), mul); auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig); sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}}); auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
one = m.insert_instruction( sig = insert_common_op(m, ins, make_op("add"), {sig, one});
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one);
sig = m.insert_instruction(ins, make_op("add"), sig, one);
sig = m.insert_instruction(ins, make_op("div"), x, sig); sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig); m.replace_instruction(ins, sig);
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
...@@ -47,16 +48,13 @@ TEST_CASE(bias_gelu) ...@@ -47,16 +48,13 @@ TEST_CASE(bias_gelu)
auto b = m1.add_parameter("b", s1); auto b = m1.add_parameter("b", s1);
auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b); auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}}); auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1); auto div = add_common_op(m1, migraphx::make_op("div"), {add1, l1});
auto div = m1.add_instruction(migraphx::make_op("div"), add1, l1);
auto erf = m1.add_instruction(migraphx::make_op("erf"), div); auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2); auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto add2 = m1.add_instruction(migraphx::make_op("add"), erf, l2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2); auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}}); auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
l3 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l3); mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
mul = m1.add_instruction(migraphx::make_op("mul"), mul, l3);
m1.add_return({mul}); m1.add_return({mul});
} }
migraphx::rewrite_gelu pass; migraphx::rewrite_gelu pass;
...@@ -70,13 +68,11 @@ TEST_CASE(bias_gelu) ...@@ -70,13 +68,11 @@ TEST_CASE(bias_gelu)
auto b = m2.add_parameter("b", s1); auto b = m2.add_parameter("b", s1);
auto add = m2.add_instruction(migraphx::make_op("add"), a, b); auto add = m2.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}}); auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
l1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1); auto mul = add_common_op(m2, migraphx::make_op("mul"), {add, l1});
auto mul = m2.add_instruction(migraphx::make_op("mul"), add, l1);
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul); auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig); sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2); sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("add"), sig, l2);
sig = m2.add_instruction(migraphx::make_op("div"), add, sig); sig = m2.add_instruction(migraphx::make_op("div"), add, sig);
m2.add_return({sig}); m2.add_return({sig});
} }
...@@ -94,16 +90,13 @@ TEST_CASE(non_bias_gelu) ...@@ -94,16 +90,13 @@ TEST_CASE(non_bias_gelu)
auto b = m1.add_parameter("b", s1); auto b = m1.add_parameter("b", s1);
auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b); auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}}); auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1); auto div = add_common_op(m1, migraphx::make_op("div"), {sub, l1});
auto div = m1.add_instruction(migraphx::make_op("div"), sub, l1);
auto erf = m1.add_instruction(migraphx::make_op("erf"), div); auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2); auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto add2 = m1.add_instruction(migraphx::make_op("add"), erf, l2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), sub, add2); auto mul = m1.add_instruction(migraphx::make_op("mul"), sub, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}}); auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
l3 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l3); mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
mul = m1.add_instruction(migraphx::make_op("mul"), mul, l3);
m1.add_return({mul}); m1.add_return({mul});
} }
migraphx::rewrite_gelu pass; migraphx::rewrite_gelu pass;
...@@ -117,13 +110,11 @@ TEST_CASE(non_bias_gelu) ...@@ -117,13 +110,11 @@ TEST_CASE(non_bias_gelu)
auto b = m2.add_parameter("b", s1); auto b = m2.add_parameter("b", s1);
auto sub = m2.add_instruction(migraphx::make_op("sub"), a, b); auto sub = m2.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}}); auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
l1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1); auto mul = add_common_op(m2, migraphx::make_op("mul"), {sub, l1});
auto mul = m2.add_instruction(migraphx::make_op("mul"), sub, l1);
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul); auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig); sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2); sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("add"), sig, l2);
sig = m2.add_instruction(migraphx::make_op("div"), sub, sig); sig = m2.add_instruction(migraphx::make_op("div"), sub, sig);
m2.add_return({sig}); m2.add_return({sig});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment