Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -2,14 +2,12 @@ ...@@ -2,14 +2,12 @@
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; } bool is_pooling(migraphx::instruction& ins) { return ins.name() == "pooling"; }
...@@ -29,7 +27,11 @@ TEST_CASE(rewrite_pooling_test) ...@@ -29,7 +27,11 @@ TEST_CASE(rewrite_pooling_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s); auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}}, auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
input); input);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
...@@ -39,9 +41,10 @@ TEST_CASE(rewrite_pooling_test) ...@@ -39,9 +41,10 @@ TEST_CASE(rewrite_pooling_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s); auto input = mm->add_parameter("x", s);
auto rsp = mm->add_instruction(migraphx::op::reshape{{4, -1}}, input); auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
auto rdm = mm->add_instruction(reduce_op, rsp); auto rdm = mm->add_instruction(reduce_op, rsp);
auto ret = mm->add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm); auto ret =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
}; };
...@@ -53,8 +56,8 @@ TEST_CASE(rewrite_pooling_test) ...@@ -53,8 +56,8 @@ TEST_CASE(rewrite_pooling_test)
EXPECT(p1 == p2); EXPECT(p1 == p2);
}; };
test_rewrite("average", migraphx::op::reduce_mean{{1}}); test_rewrite("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
test_rewrite("max", migraphx::op::reduce_max{{1}}); test_rewrite("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
} }
TEST_CASE(rewrite_avepooling_na1_test) TEST_CASE(rewrite_avepooling_na1_test)
...@@ -65,8 +68,12 @@ TEST_CASE(rewrite_avepooling_na1_test) ...@@ -65,8 +68,12 @@ TEST_CASE(rewrite_avepooling_na1_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s); auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction( auto ret = mm->add_instruction(migraphx::make_op("pooling",
migraphx::op::pooling{"average", {0, 1, 0}, {1, 1, 1}, {3, 4, 5}}, input); {{"mode", "average"},
{"padding", {0, 1, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
input);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
}; };
...@@ -86,8 +93,12 @@ TEST_CASE(rewrite_avepooling_na2_test) ...@@ -86,8 +93,12 @@ TEST_CASE(rewrite_avepooling_na2_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s); auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction( auto ret = mm->add_instruction(migraphx::make_op("pooling",
migraphx::op::pooling{"average", {0, 0, 0}, {1, 2, 1}, {3, 4, 5}}, input); {{"mode", "average"},
{"padding", {0, 0, 0}},
{"stride", {1, 2, 1}},
{"lengths", {3, 4, 5}}}),
input);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
}; };
...@@ -107,8 +118,12 @@ TEST_CASE(rewrite_avepooling_na3_test) ...@@ -107,8 +118,12 @@ TEST_CASE(rewrite_avepooling_na3_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", s); auto input = mm->add_parameter("x", s);
auto ret = mm->add_instruction( auto ret = mm->add_instruction(migraphx::make_op("pooling",
migraphx::op::pooling{"max", {0, 0, 0}, {1, 1, 1}, {3, 3, 5}}, input); {{"mode", "max"},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 3, 5}}}),
input);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
}; };
...@@ -131,7 +146,11 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -131,7 +146,11 @@ TEST_CASE(literal_rewrite_pooling_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_literal(migraphx::literal(s, data)); auto input = mm->add_literal(migraphx::literal(s, data));
auto ret = mm->add_instruction(migraphx::op::pooling{mode, {0, 0, 0}, {1, 1, 1}, {3, 4, 5}}, auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", mode},
{"padding", {0, 0, 0}},
{"stride", {1, 1, 1}},
{"lengths", {3, 4, 5}}}),
input); input);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
...@@ -141,9 +160,10 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -141,9 +160,10 @@ TEST_CASE(literal_rewrite_pooling_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_literal(migraphx::literal(s, data)); auto input = mm->add_literal(migraphx::literal(s, data));
auto rsp = mm->add_instruction(migraphx::op::reshape{{4, -1}}, input); auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), input);
auto rdm = mm->add_instruction(op, rsp); auto rdm = mm->add_instruction(op, rsp);
auto ret = mm->add_instruction(migraphx::op::reshape{{2, 2, 1, 1, 1}}, rdm); auto ret =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 1, 1, 1}}}), rdm);
mm->add_return({ret}); mm->add_return({ret});
return p; return p;
...@@ -160,8 +180,8 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -160,8 +180,8 @@ TEST_CASE(literal_rewrite_pooling_test)
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}; };
test_rewrite_pooling("max", migraphx::op::reduce_max{{1}}); test_rewrite_pooling("max", migraphx::make_op("reduce_max", {{"axes", {1}}}));
test_rewrite_pooling("average", migraphx::op::reduce_mean{{1}}); test_rewrite_pooling("average", migraphx::make_op("reduce_mean", {{"axes", {1}}}));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment