Commit 9ccf5e5c authored by turneram's avatar turneram
Browse files

Use common.hpp to add broadcasts

parent b07f5e4e
......@@ -26,6 +26,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -42,15 +43,11 @@ struct find_gelu_erf
return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lit);
mul = m.insert_instruction(ins, make_op("mul"), x, mul);
auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
one = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one);
sig = m.insert_instruction(ins, make_op("add"), sig, one);
sig = insert_common_op(m, ins, make_op("add"), {sig, one});
sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig);
}
......
......@@ -32,6 +32,7 @@
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/serialize.hpp>
......@@ -47,16 +48,13 @@ TEST_CASE(bias_gelu)
auto b = m1.add_parameter("b", s1);
auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto div = m1.add_instruction(migraphx::make_op("div"), add1, l1);
auto div = add_common_op(m1, migraphx::make_op("div"), {add1, l1});
auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2);
auto add2 = m1.add_instruction(migraphx::make_op("add"), erf, l2);
auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
l3 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l3);
mul = m1.add_instruction(migraphx::make_op("mul"), mul, l3);
mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
m1.add_return({mul});
}
migraphx::rewrite_gelu pass;
......@@ -70,13 +68,11 @@ TEST_CASE(bias_gelu)
auto b = m2.add_parameter("b", s1);
auto add = m2.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
l1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), add, l1);
auto mul = add_common_op(m2, migraphx::make_op("mul"), {add, l1});
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2);
sig = m2.add_instruction(migraphx::make_op("add"), sig, l2);
sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("div"), add, sig);
m2.add_return({sig});
}
......@@ -94,16 +90,13 @@ TEST_CASE(non_bias_gelu)
auto b = m1.add_parameter("b", s1);
auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto div = m1.add_instruction(migraphx::make_op("div"), sub, l1);
auto div = add_common_op(m1, migraphx::make_op("div"), {sub, l1});
auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2);
auto add2 = m1.add_instruction(migraphx::make_op("add"), erf, l2);
auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto mul = m1.add_instruction(migraphx::make_op("mul"), sub, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
l3 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l3);
mul = m1.add_instruction(migraphx::make_op("mul"), mul, l3);
mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
m1.add_return({mul});
}
migraphx::rewrite_gelu pass;
......@@ -117,13 +110,11 @@ TEST_CASE(non_bias_gelu)
auto b = m2.add_parameter("b", s1);
auto sub = m2.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
l1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), sub, l1);
auto mul = add_common_op(m2, migraphx::make_op("mul"), {sub, l1});
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2);
sig = m2.add_instruction(migraphx::make_op("add"), sig, l2);
sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("div"), sub, sig);
m2.add_return({sig});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment