Unverified Commit 0fa539da authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Bug simplify algegra (#753)



* fix issue#727

* clang format

* refine unit tests

* fix cppcheck error

* fix review comments

* refine a unit test to cover more code changes

* fix cppcheck error

* remove unnecessary include file

* fix review comments

* clang format
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent d0e8bb1a
......@@ -480,6 +480,12 @@ inline auto name(std::string s)
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
}
inline auto name_contains(const std::string& name)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) { return contains(ins->get_operator().name(), name); });
}
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
......@@ -643,6 +649,7 @@ inline auto has_attribute(const std::string& name)
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -46,6 +46,8 @@ auto pointwise(Ms... ms)
ms...);
}
auto reduction() { return match::name_contains("reduce"); }
struct find_mul_conv
{
auto matcher() const
......@@ -406,7 +408,7 @@ struct find_splits
auto matcher() const
{
return match::any(match::any_of[match::outputs()](
match::name("slice")(match::any_of[match::outputs()](pointwise()))));
match::name("slice")(match::any_of[match::outputs()](pointwise(), reduction()))));
}
static std::vector<std::vector<instruction_ref>>
......@@ -439,6 +441,31 @@ struct find_splits
return groups;
}
bool is_fusable(instruction_ref start, instruction_ref split_front) const
{
auto op = start->get_operator();
if(contains(op.name(), "reduce"))
{
auto slc = any_cast<op::slice>(split_front->get_operator());
auto slc_axes = slc.axes;
auto reduce_axes = start->get_operator().to_value()["axes"].to_vector<int64_t>();
// axes of slice and reduce op cannot have overlap
if(std::any_of(slc_axes.begin(), slc_axes.end(), [&](auto axis) {
return (std::find(reduce_axes.begin(), reduce_axes.end(), axis) !=
reduce_axes.end());
}))
{
return false;
}
}
else if(not op.attributes().contains("pointwise"))
{
return false;
}
return true;
}
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -446,12 +473,16 @@ struct find_splits
auto splits = get_splits(ins);
if(splits.empty())
return;
for(const auto& group : get_split_groups(splits))
{
auto start = group.front();
auto op = start->get_operator();
if(op.name() == "slice")
auto start = group.front();
auto split_front = splits.front();
auto op = start->get_operator();
if(not is_fusable(start, split_front))
{
continue;
}
// Make sure there is no duplicates
assert(std::none_of(
......
......@@ -918,6 +918,113 @@ TEST_CASE(simplify_split_add_relu)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_split_reduce0)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m1.add_literal(1);
auto two = m1.add_literal(2);
auto arx = m1.add_instruction(migraphx::make_op("contiguous"), x);
auto ary = m1.add_instruction(migraphx::make_op("contiguous"), y);
auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), x);
auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x);
auto rmax1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), arx, one);
auto rmin1 = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), ary, two);
auto rmax2 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), y);
auto rmin2 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y);
m1.add_return({rmax0, rmin0, rmax1, rmin1, rmax2, rmin2});
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_split_reduce1)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), x);
auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), x);
auto rmax2 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), y);
auto rmin2 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), y);
m1.add_return({rmax0, rmin0, rmax2, rmin2});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto rmn = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 2}}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmn);
auto rmx = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rmx);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmn);
auto slc3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rmx);
m2.add_return({slc3, slc2, slc1, slc0});
}
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_split_reduce2)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto rmax0 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), x);
auto rmin0 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x);
auto rmax2 = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), y);
auto rmin2 = m1.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y);
m1.add_return({rmax0, rmin0, rmax2, rmin2});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto rmn1 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), x);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto rmn2 = m2.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {0, 1}}}), y);
auto rms = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 2}}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rms);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rms);
m2.add_return({slc1, rmn2, slc0, rmn1});
}
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_split_add_relu_reshape)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment