Unverified Commit 4b86a0aa authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Exclude param from deadcode elimiation (#910)



* always keep parameters

* clang format

* fix tidy error

* clang format

* add more unit tests to have more code coverage

* fixed a bug to ensure get_parameter_names to return ordered parameter names

* clang format

* remove unnecessary print out

* refine a code change

* clang format

* add a unit test to check parameter is not removed by dead code elimination

* clang format

* rename a function name
Co-authored-by: default avatarChris Austen <causten@users.noreply.github.com>
parent 6be85674
...@@ -60,7 +60,8 @@ void dead_code_elimination::apply(module& m) const ...@@ -60,7 +60,8 @@ void dead_code_elimination::apply(module& m) const
leaf->clear_arguments(); leaf->clear_arguments();
assert(bidistance(m, last, leaf) < 0); assert(bidistance(m, last, leaf) < 0);
assert(leaf != ins); assert(leaf != ins);
m.move_instruction(leaf, m.end()); if(leaf->name() != "@param")
m.move_instruction(leaf, m.end());
for(auto arg : args) for(auto arg : args)
self(arg); self(arg);
} }
......
...@@ -125,9 +125,10 @@ void module::assign(const module& m) ...@@ -125,9 +125,10 @@ void module::assign(const module& m)
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter; auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto order = any_cast<builtin::param>(ins->get_operator()).order;
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = copy_ins = impl->insert(impl->instructions.end(),
impl->insert(impl->instructions.end(), {builtin::param{name}, std::move(s), {}}); {builtin::param{name, order}, std::move(s), {}});
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
......
...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void preallocate_param::apply(module& m) const void preallocate_param::apply(module& m) const
{ {
auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "@param") if(ins->name() != "@param")
...@@ -19,7 +20,9 @@ void preallocate_param::apply(module& m) const ...@@ -19,7 +20,9 @@ void preallocate_param::apply(module& m) const
std::string id = m.name() + ":" + param; std::string id = m.name() + ":" + param;
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id)); auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
m.replace_instruction(ins, r); m.replace_instruction(ins, r);
m.move_instruction(ins, m.end());
} }
m.remove_instructions(std::next(last), m.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -239,6 +239,18 @@ struct stream_info ...@@ -239,6 +239,18 @@ struct stream_info
} }
} }
} }
// move dangling parameter to the front so as not be removed
auto ins = std::next(last);
while(ins != p.end())
{
auto next = std::next(ins);
if(ins->name() == "@param")
{
p.move_instruction(ins, p.begin());
}
ins = next;
}
} }
void set_stream(const partition& p, std::size_t n) void set_stream(const partition& p, std::size_t n)
...@@ -510,6 +522,9 @@ void schedule::apply(module& p) const ...@@ -510,6 +522,9 @@ void schedule::apply(module& p) const
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{ {
p.annotate(std::cout, [&](auto ins) { p.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins))
return;
std::cout << ":"; std::cout << ":";
std::cout << " weight=" << si.weights.at(ins); std::cout << " weight=" << si.weights.at(ins);
std::cout << " input={"; std::cout << " input={";
...@@ -550,11 +565,9 @@ void schedule::apply(module& p) const ...@@ -550,11 +565,9 @@ void schedule::apply(module& p) const
{ {
for(auto i : si.get_recorded_instructions(ins)) for(auto i : si.get_recorded_instructions(ins))
{ {
if(not si.has_stream(i)) if(not si.has_stream(i) or si.get_stream(i) == stream)
continue;
auto istream = si.get_stream(i);
if(stream == istream)
continue; continue;
// Create a new event if it hasn't been recorded // Create a new event if it hasn't been recorded
if(not contains(ins2wait, i)) if(not contains(ins2wait, i))
{ {
......
...@@ -197,4 +197,24 @@ TEST_CASE(unused_module) ...@@ -197,4 +197,24 @@ TEST_CASE(unused_module)
EXPECT(not migraphx::contains(p.get_modules(), m1)); EXPECT(not migraphx::contains(p.get_modules(), m1));
} }
TEST_CASE(param_not_eliminated)
{
auto create_program = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {2, 2}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
mm->add_parameter("z", s);
auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_return({sum});
return p;
};
auto p = create_program();
run_pass(p);
EXPECT(p == create_program());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -175,6 +175,7 @@ TEST_CASE(inline_else_test) ...@@ -175,6 +175,7 @@ TEST_CASE(inline_else_test)
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
mm->add_parameter("x", s); mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
mm->add_parameter("e", s);
auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2); auto r = mm->add_instruction(migraphx::make_op("mul"), y, l2);
mm->add_return({r}); mm->add_return({r});
......
...@@ -253,4 +253,27 @@ TEST_CASE(submodule_copy) ...@@ -253,4 +253,27 @@ TEST_CASE(submodule_copy)
EXPECT(mm.get_sub_modules() == mm2.get_sub_modules()); EXPECT(mm.get_sub_modules() == mm2.get_sub_modules());
} }
TEST_CASE(parameter_name_order)
{
migraphx::shape s{migraphx::shape::int32_type, {1}};
migraphx::module mm("main");
auto x1 = mm.add_parameter("x1", s);
auto x2 = mm.add_parameter("x2", s);
auto x3 = mm.add_parameter("x3", s);
auto x4 = mm.add_parameter("x4", s);
std::vector<std::string> param_names = {"x1", "x2", "x3", "x4"};
auto sum1 = mm.add_instruction(migraphx::make_op("add"), x1, x2);
auto sum2 = mm.add_instruction(migraphx::make_op("add"), x3, x4);
auto r = mm.add_instruction(migraphx::make_op("mul"), sum1, sum2);
mm.add_return({r});
auto names = mm.get_parameter_names();
EXPECT(param_names == names);
auto m1 = mm;
auto names1 = m1.get_parameter_names();
EXPECT(param_names == names1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -962,4 +962,23 @@ TEST_CASE(if_pl_test) ...@@ -962,4 +962,23 @@ TEST_CASE(if_pl_test)
EXPECT(t.has_stream(r2) == false); EXPECT(t.has_stream(r2) == false);
} }
TEST_CASE(unused_param_test)
{
migraphx::module mm;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
auto x = mm.add_parameter("x", s);
auto y = mm.add_parameter("y", s);
auto z = mm.add_parameter("z", s);
auto r = mm.add_instruction(migraphx::make_op("add"), x, y);
mm.add_return({r});
scheduler t{};
t.run_pass(mm);
EXPECT(t.has_stream(z) == false);
EXPECT(t.has_stream(r) == false);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -714,20 +714,21 @@ TEST_CASE(optimize_where_true) ...@@ -714,20 +714,21 @@ TEST_CASE(optimize_where_true)
return m; return m;
}; };
auto create_opt_module = [&](std::string name) { auto return_xy = [&](bool cond) {
migraphx::module m; migraphx::module m;
auto in = m.add_parameter(std::move(name), s); auto x = m.add_parameter("X", s);
m.add_return({in}); auto y = m.add_parameter("Y", s);
cond ? m.add_return({x}) : m.add_return({y});
return m; return m;
}; };
auto m = create_where_module(true); auto m = create_where_module(true);
run_pass(m); run_pass(m);
EXPECT(m == create_opt_module("X")); EXPECT(m == return_xy(true));
auto m1 = create_where_module(false); auto m1 = create_where_module(false);
run_pass(m1); run_pass(m1);
EXPECT(m1 == create_opt_module("Y")); EXPECT(m1 == return_xy(false));
} }
TEST_CASE(where_different_cond_values) TEST_CASE(where_different_cond_values)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment