Unverified commit 2ea40daa, authored by Paul Fultz II and committed by GitHub
Browse files

Fix bert fusions (#666)



* Fix fusions in bert model

* Formatting

* Add unit tests

* Formatting

* Fix one_half matcher

* Workaround ICE on gcc

* Formatting

* Tidy fixes
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent f279cda6
...@@ -131,10 +131,22 @@ auto pack(Ts... xs) ...@@ -131,10 +131,22 @@ auto pack(Ts... xs)
return [=](auto f) { return f(xs...); }; return [=](auto f) { return f(xs...); };
} }
// Base case: joining no packs yields whatever `pack()` returns for zero arguments.
inline auto pack_join()
{
    return pack();
}

// Joins several packs into a single pack-like callable.
//
// The returned callable takes a function `f`, asks `p` for its arguments, asks the
// join of the remaining packs `ps...` for theirs, and finally calls `f` with the
// arguments of `p` followed by the arguments of the rest, preserving order.
// Every pack is copied into the returned closure; the inner lambdas only borrow
// by reference because they are invoked before the closure call returns.
template <class P, class... Ps>
auto pack_join(P p, Ps... ps)
{
    return [=](auto f) {
        return p([&](auto... head) {
            return pack_join(ps...)([&](auto... tail) { return f(head..., tail...); });
        });
    };
}
template <class F, class T> template <class F, class T>
auto fold_impl(F&&, T&& x) auto fold_impl(F&&, T&& x)
{ {
return x; return std::forward<T>(x);
} }
template <class F, class T, class U, class... Ts> template <class F, class T, class U, class... Ts>
......
...@@ -80,6 +80,8 @@ struct instruction ...@@ -80,6 +80,8 @@ struct instruction
static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false); static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);
void debug_print() const;
private: private:
// internal // internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args); void replace(operation o, const shape& r, std::vector<instruction_ref> args);
......
...@@ -28,6 +28,12 @@ struct matcher_context ...@@ -28,6 +28,12 @@ struct matcher_context
return m.match(*this, ins) != this->not_found(); return m.match(*this, ins) != this->not_found();
} }
template <class M>
auto lazy_match(M m, instruction_ref ins)
{
return [=] { return this->matched(m, ins); };
}
private: private:
instruction_ref last; instruction_ref last;
}; };
...@@ -270,18 +276,18 @@ find_skip<M> make_find_skip(M m) ...@@ -270,18 +276,18 @@ find_skip<M> make_find_skip(M m)
struct lazy_and struct lazy_and
{ {
template <class F, class G> template <class F, class G>
bool operator()(F f, G g) const auto operator()(F f, G g) const
{ {
return f() and g(); return [=] { return f() and g(); };
} }
}; };
struct lazy_or struct lazy_or
{ {
template <class F, class G> template <class F, class G>
bool operator()(F f, G g) const auto operator()(F f, G g) const
{ {
return f() or g(); return [=] { return f() or g(); };
} }
}; };
...@@ -293,7 +299,7 @@ struct match_fold_f ...@@ -293,7 +299,7 @@ struct match_fold_f
{ {
Op op; Op op;
auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; }; auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...); return fold(op)(always(Start), matched(ms)...)();
} }
template <class Pack> template <class Pack>
...@@ -324,7 +330,7 @@ struct match_fold_f ...@@ -324,7 +330,7 @@ struct match_fold_f
bool matches = Start; bool matches = Start;
select(start, [&](auto ins) { select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); }; auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm); matches = op(always(matches), fm)();
}); });
if(matches == Matches) if(matches == Matches)
return start; return start;
...@@ -535,6 +541,66 @@ inline auto any_arg(std::size_t i, std::size_t j) ...@@ -535,6 +541,66 @@ inline auto any_arg(std::size_t i, std::size_t j)
return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); }; return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); };
} }
// Collects, depth-first and left to right, the leaves of the tree rooted at `ins`.
//
// An instruction is treated as an interior node when its name equals `s` and it
// has at least two inputs; only inputs 0 and 1 are descended into. Any other
// instruction is a leaf and is stored in the next free slot of `leafs`.
//
// Returns the number of slots filled. Traversal stops as soon as `leafs` is full,
// so a return value equal to `leafs.size()` is produced both when the tree has
// exactly N leaves and when it has more.
// NOTE(review): callers compare the result against `leafs.size()`; confirm that
// accepting a truncated, larger tree is intended.
template <std::size_t N>
std::size_t
tree_leafs_impl(std::array<instruction_ref, N>& leafs, const std::string& s, instruction_ref ins)
{
    std::size_t filled = 0;
    fix([&](auto self, auto node) {
        // No room left: ignore the remainder of the tree.
        if(filled == leafs.size())
            return;
        const bool is_interior = node->name() == s and node->inputs().size() >= 2;
        if(not is_interior)
        {
            leafs[filled] = node;
            ++filled;
            return;
        }
        self(node->inputs()[0]);
        self(node->inputs()[1]);
    })(ins);
    return filled;
}
// Builds a matcher for a tree of instructions named `s` whose leaves match `ms...`
// position by position.
//
// The leaves under `ins` are flattened with tree_leafs_impl into an array with one
// slot per matcher. If not every slot gets filled the match fails. Otherwise each
// matcher is paired with the leaf selected by the corresponding index from
// sequence_c (presumably 0..N-1 in order - the unit tests rely on ordering) and
// all pairs are combined with lazy_and. On success the matcher yields `ins`,
// otherwise `ctx.not_found()`.
template <class... Ms>
auto tree(std::string s, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(leafs, s, ins);
// Fewer leaves than matchers: cannot match.
if(idx != leafs.size())
return ctx.not_found();
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
// Fold the deferred per-leaf matches with lazy_and, then evaluate the result.
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return ctx.not_found();
return ins;
});
}
// Builds a matcher for a tree of instructions named `s` whose leaves match `ms...`
// regardless of leaf order.
//
// The leaves under `ins` are flattened with tree_leafs_impl into an array with one
// slot per matcher. If not every slot gets filled the match fails. Otherwise, for
// each matcher, a lazy_or over that matcher tried against every leaf is built, and
// the per-matcher results are combined with lazy_and (assuming `by` applies the
// inner lambda to each of `ms...` before folding - TODO confirm). On success the
// matcher yields `ins`, otherwise `ctx.not_found()`.
// NOTE(review): each matcher is checked against all leaves independently; nothing
// visible here prevents two matchers from being satisfied by the same leaf.
template <class... Ms>
auto unordered_tree(std::string s, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(leafs, s, ins);
// Fewer leaves than matchers: cannot match.
if(idx != leafs.size())
return ctx.not_found();
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
// Deferred "m matches at least one leaf".
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
});
if(not found)
return ctx.not_found();
return ins;
});
}
template <class M> template <class M>
auto same_shape(M m) auto same_shape(M m)
{ {
...@@ -563,7 +629,8 @@ inline auto has_value(T x, float tolerance = 1e-6) ...@@ -563,7 +629,8 @@ inline auto has_value(T x, float tolerance = 1e-6)
return false; return false;
bool b = false; bool b = false;
l.visit([&](auto v) { l.visit([&](auto v) {
if(std::all_of(v.begin(), v.end(), [&](auto val) { return val - x < tolerance; })) if(std::all_of(
v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; }))
b = true; b = true;
}); });
return b; return b;
......
...@@ -219,6 +219,37 @@ void instruction::finalize(context& ctx) ...@@ -219,6 +219,37 @@ void instruction::finalize(context& ctx)
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs())); this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
} }
// Writes a compact description of `ins` to `os`.
//
// For any instruction other than "@literal" this is the streamed operator. For a
// literal it is "@literal" followed by the literal's contents in braces, or by
// "{ ... }" when the literal holds more than 10 elements.
static void debug_name(std::ostream& os, const instruction& ins)
{
    if(not(ins.name() == "@literal"))
    {
        os << ins.get_operator();
        return;
    }

    os << "@literal";
    const bool too_large = ins.get_literal().get_shape().elements() > 10;
    if(too_large)
        os << "{ ... }";
    else
        os << "{" << ins.get_literal() << "}";
}
// Prints a one-line summary of this instruction to std::cout in the form
// `name(arg0, arg1, ...) -> shape`, using debug_name for this instruction and for
// each input. The parenthesised list is omitted when there are no inputs. The
// line is terminated with std::endl, which also flushes the stream.
void instruction::debug_print() const
{
    debug_name(std::cout, *this);

    bool is_first = true;
    for(auto input : this->inputs())
    {
        // Opening parenthesis before the first input, comma separator afterwards.
        std::cout << (is_first ? "(" : ", ");
        debug_name(std::cout, *input);
        is_first = false;
    }
    if(not this->inputs().empty())
        std::cout << ")";

    std::cout << " -> " << this->get_shape() << std::endl;
}
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow) instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{ {
auto i = ins->get_operator().output_alias(to_shapes(ins->inputs())); auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
......
...@@ -105,6 +105,7 @@ struct find_nop_reshapes ...@@ -105,6 +105,7 @@ struct find_nop_reshapes
reshapes.insert("as_shape"); reshapes.insert("as_shape");
reshapes.insert("broadcast"); reshapes.insert("broadcast");
reshapes.insert("concat"); reshapes.insert("concat");
reshapes.insert("convert");
reshapes.insert("multibroadcast"); reshapes.insert("multibroadcast");
reshapes.insert("pad"); reshapes.insert("pad");
reshapes.insert("slice"); reshapes.insert("slice");
......
...@@ -354,13 +354,18 @@ struct find_gelu ...@@ -354,13 +354,18 @@ struct find_gelu
match::has_value(M_SQRT1_2))))); match::has_value(M_SQRT1_2)))));
} }
// Matcher for the `erf(...) + 1` part of the pattern: a "gpu::add" that is used
// once and whose arguments 0 and 1 are matched via either_arg by erf_fn() and by
// an instruction whose arguments match has_value(1.0f) (presumably in either
// order, as the name either_arg suggests).
static auto add_erf()
{
    auto erf_and_one =
        match::either_arg(0, 1)(erf_fn(), match::args(match::has_value(1.0f)));
    return match::name("gpu::add")(match::used_once(), erf_and_one);
}
// Matcher for an instruction whose arguments match has_value(0.5f).
static auto one_half()
{
    return match::args(match::has_value(0.5f));
}
auto matcher() const auto matcher() const
{ {
return match::name("gpu::mul")(match::either_arg(0, 1)( return match::unordered_tree("gpu::mul", one_half(), add_erf(), match::any());
match::name("gpu::mul")(match::any_arg(0, 1)(match::args(match::has_value(0.5f)))),
match::name("gpu::add")(
match::used_once(),
match::either_arg(0, 1)(erf_fn(), match::args(match::has_value(1.0f))))));
} }
void apply(program& p, match::matcher_result r) const void apply(program& p, match::matcher_result r) const
......
...@@ -767,6 +767,229 @@ TEST_CASE(match_bind1) ...@@ -767,6 +767,229 @@ TEST_CASE(match_bind1)
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
// Program: sum2 = sum(sum(1, 2), 2), passed through pass_op.
// has_value(1) is expected to match the literal `one`.
TEST_CASE(match_has_value1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::has_value(1);
auto r = find_match(p, m);
EXPECT(bool{r.result == one});
}
// Program: sum2 = sum(sum(1, 2), 2), passed through pass_op.
// has_value(2) is expected to match the literal `two`.
TEST_CASE(match_has_value2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::has_value(2);
auto r = find_match(p, m);
EXPECT(bool{r.result == two});
}
// Program: sum2 = sum(sum(1, 2), 2), passed through pass_op.
// A "sum" whose args match has_value(1) then has_value(2) is expected to be sum1.
TEST_CASE(match_has_value3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(2)));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum1});
}
// Program: sum2 = sum(sum(1, 2), 2), passed through pass_op.
// No literal 3 is added, so has_value(3) is expected to find nothing (p.end()).
TEST_CASE(match_has_value4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::has_value(3);
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(sum(1, 2), 2), passed through pass_op.
// A "sum" with args has_value(1), has_value(3) is expected to find nothing (p.end()).
TEST_CASE(match_has_value5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(sum(1, 2), 2), passed through pass_op.
// The arg matchers are given as has_value(2), has_value(1) - the reverse of
// sum1's inputs - and the match is expected to find nothing (p.end()).
TEST_CASE(match_has_value6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// tree("sum", 1, 2, 3) lists the leaves in left-to-right order and is expected
// to match sum2.
TEST_CASE(match_tree1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(1), match::has_value(2), match::has_value(3));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// tree("sum", 2, 1, 3) swaps the first two leaf matchers and is expected to find
// nothing (p.end()).
TEST_CASE(match_tree2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(3, sum(1, 2)), passed through pass_op (nested sum on the right).
// tree("sum", 3, 1, 2) is expected to match sum2.
TEST_CASE(match_tree3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, three, sum1);
p.add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(3), match::has_value(1), match::has_value(2));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// Four leaf matchers are given for a tree built from three literals; the match is
// expected to find nothing (p.end()).
TEST_CASE(match_tree4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m = match::tree(
"sum", match::has_value(1), match::has_value(2), match::has_value(3), match::has_value(4));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// tree("sum", 2, 3) gives only two leaf matchers and is expected to find nothing
// (p.end()).
TEST_CASE(match_tree5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(3));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// tree("sum", 1, 3) gives only two leaf matchers and is expected to find nothing
// (p.end()).
TEST_CASE(match_tree6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(1), match::has_value(3));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// unordered_tree("sum", 3, 2, 1) lists the leaf values in a different order than
// they appear and is still expected to match sum2.
TEST_CASE(match_unordered_tree1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m =
match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
// Program: sum2 = sum(3, sum(1, 2)), passed through pass_op (nested sum on the right).
// unordered_tree("sum", 3, 2, 1) is expected to match sum2.
TEST_CASE(match_unordered_tree2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, three, sum1);
p.add_instruction(pass_op{}, sum2);
auto m =
match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
// Program: sum2 = sum(sum(2, 1), 3), passed through pass_op (inner inputs swapped).
// unordered_tree("sum", 3, 2, 1) is expected to match sum2.
TEST_CASE(match_unordered_tree3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, two, one);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m =
match::unordered_tree("sum", match::has_value(3), match::has_value(2), match::has_value(1));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
// Program: sum2 = sum(sum(1, 2), 3), passed through pass_op.
// unordered_tree("sum", 4, 2, 1) asks for a value 4 that is never added, so the
// match is expected to find nothing (p.end()).
TEST_CASE(match_unordered_tree4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto three = p.add_literal(3);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, three);
p.add_instruction(pass_op{}, sum2);
auto m =
match::unordered_tree("sum", match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
struct match_find_sum struct match_find_sum
{ {
migraphx::instruction_ref ins; migraphx::instruction_ref ins;
......
...@@ -259,6 +259,20 @@ TEST_CASE(nop_transpose3) ...@@ -259,6 +259,20 @@ TEST_CASE(nop_transpose3)
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
// A convert to float_type applied to a float_type parameter of shape {1, 2, 3}
// changes nothing. After run_pass the program is expected to keep the same output
// shape and to contain exactly one instruction fewer.
TEST_CASE(nop_convert)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, x);
p.add_return({t});
// Snapshot the output shape and instruction count before running the pass.
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(concat_transpose1) TEST_CASE(concat_transpose1)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment