Unverified Commit b847e868 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Bug split optimization (#817)



* backup implementation of resize enhancement

* clang format

* code backup for the resize

* clang format

* fix build error for resize operator

* clang format

* tmp code backup

* clang format

* remove changes in parse_resize

* remove unnecessary changes

* clang format

* add unit test for the bug

* clang format

* remove print code

* remove a semi-colon

* clang format

* fix a tidy error

* fix review comments

* clang format
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 7ab06956
......@@ -20,6 +20,7 @@
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -403,8 +404,27 @@ struct find_splits
match::any_of[match::outputs()](match::pointwise(), reduction()))));
}
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
{
std::unordered_set<instruction_ref> traversed;
return fix<bool>([&](auto self, auto ins) -> bool {
if(ins == ins2)
return true;
if(contains(traversed, ins))
return false;
traversed.insert(ins);
const auto& inputs = ins->inputs();
return std::any_of(inputs.begin(), inputs.end(), [&](auto in) {
return m.has_instruction(in) and self(in);
});
})(ins1);
}
static std::vector<std::vector<instruction_ref>>
get_split_groups(const std::vector<instruction_ref>& splits)
get_split_groups(const module& m, const std::vector<instruction_ref>& splits)
{
std::vector<std::vector<instruction_ref>> groups;
for(auto out : splits.front()->outputs())
......@@ -421,9 +441,16 @@ struct find_splits
if(it == split->outputs().end())
break;
assert((*it)->name() != "slice");
// If there is a duplicate bail
if(contains(group, *it))
// there are should be no dependency between instructions in the group
if(std::any_of(group.begin(), group.end(), [&](auto i) {
return is_dependent(m, *it, i) or is_dependent(m, i, *it);
}))
{
return {};
}
group.push_back(*it);
}
if(group.size() != splits.size())
......@@ -460,13 +487,12 @@ struct find_splits
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
return;
for(const auto& group : get_split_groups(splits))
for(const auto& group : get_split_groups(p, splits))
{
auto start = group.front();
auto split_front = splits.front();
......
......@@ -2132,4 +2132,32 @@ TEST_CASE(reorder_slice_trans_diff_perm)
test(4);
}
TEST_CASE(reorder_slice_ins_deps)
{
auto create_module = [] {
migraphx::module m;
migraphx::shape sx{migraphx::shape::float_type, {4, 2}};
migraphx::shape sy{migraphx::shape::float_type, {2, 2}};
std::vector<float> datax = {0, 1, 2, 3, 4, 5, 6, 7};
std::vector<float> datay = {0, 1, 2, 3};
auto inx = m.add_literal(migraphx::literal(sx, datax));
auto iny = m.add_literal(migraphx::literal(sy, datay));
auto slc0 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), inx);
auto slc1 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), inx);
auto n0 = m.add_instruction(migraphx::make_op("neg"), slc0);
auto a0 = m.add_instruction(migraphx::make_op("add"), n0, slc1);
auto m0 = m.add_instruction(migraphx::make_op("mul"), a0, iny);
auto r = m.add_instruction(migraphx::make_op("add"), m0, slc0);
m.add_return({r});
return m;
};
auto m = create_module();
run_pass(m);
EXPECT(m == create_module());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment