Unverified Commit c9b86f1c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module impl (#678)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* code backup

* code backup

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fixed a unit test error

* clang format

* fix unit test

* clang format

* code backup

* code change for more code coverage

* change program to module in various passes and matcher

* clang format

* modify the pass API

* code backup

* code backup

* clang format

* code backup

* clang format

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format

* code backup

* code backup

* clang format

* fix cppcheck errors

* clang format

* clang format

* fix build errors

* clang format

* modify gpu unit tests to using module

* clang format

* fix cppcheck error

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* fix review comments

* code backup

* clang format

* code backup

* clang format

* fix a bug related to a unit test

* clang format

* clang format

* fix a build error

* remove unnecessary code

* remove unnecessary files

* code backup

* clang format

* remove the compile function from the module class

* clang format

* clang format

* remove the context parameter from the from_value method of the module class

* code refinement

* clang format

* merge changes from develop branch

* clang format

* fix cppcheck error

* clang format

* fix a build error

* fixed a merge error

* fix cppcheck error

* fixed review comments

* clang format

* fix cppcheck error

* fix a cppcheck error

* fix cppcheck error

* fix build error caused by merge

* Add missing has_op function

* Formatting

* merge changes from develop branch

* fix a cppcheck error

* fixed some review comments

* clang format

* remove the begin/end function of the program class

* clang format

* refine code and fix cppcheck error

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* add unit tests for more code coverage

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix a build error in debug mode

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 1dd4e4d9
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
struct lowering struct lowering
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
struct operation; struct operation;
namespace gpu { namespace gpu {
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
......
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
......
...@@ -129,8 +129,8 @@ struct program_model ...@@ -129,8 +129,8 @@ struct program_model
for(auto&& race : races) for(auto&& race : races)
{ {
std::cout << "Race:\n"; std::cout << "Race:\n";
p.debug_print(race.ins); mm->debug_print(race.ins);
p.debug_print(race.before); mm->debug_print(race.before);
} }
} }
}; };
......
...@@ -17,9 +17,9 @@ TEST_CASE(simple_test) ...@@ -17,9 +17,9 @@ TEST_CASE(simple_test)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -33,9 +33,9 @@ TEST_CASE(simple_test_nop) ...@@ -33,9 +33,9 @@ TEST_CASE(simple_test_nop)
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -51,7 +51,7 @@ TEST_CASE(simple_test_nop2) ...@@ -51,7 +51,7 @@ TEST_CASE(simple_test_nop2)
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(nop{}); mm->add_instruction(nop{});
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{}); EXPECT(result == migraphx::literal{});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -65,9 +65,9 @@ TEST_CASE(duplicate_test1) ...@@ -65,9 +65,9 @@ TEST_CASE(duplicate_test1)
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -82,9 +82,9 @@ TEST_CASE(duplicate_test2) ...@@ -82,9 +82,9 @@ TEST_CASE(duplicate_test2)
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, one, two); mm->add_instruction(minus_op{}, one, two);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 2)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -101,9 +101,9 @@ TEST_CASE(depth_test) ...@@ -101,9 +101,9 @@ TEST_CASE(depth_test)
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(minus_op{}, x1, x2); mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 4)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
...@@ -117,9 +117,9 @@ TEST_CASE(undefined_test) ...@@ -117,9 +117,9 @@ TEST_CASE(undefined_test)
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::make_op("undefined")); auto undef = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1); EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
EXPECT(not mm->has_instruction(undef)); EXPECT(not mm->has_instruction(undef));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -134,10 +134,10 @@ TEST_CASE(duplicate_args1) ...@@ -134,10 +134,10 @@ TEST_CASE(duplicate_args1)
auto l3 = mm->add_literal(3); auto l3 = mm->add_literal(3);
mm->add_instruction(migraphx::make_op("add"), l3, l3); mm->add_instruction(migraphx::make_op("add"), l3, l3);
mm->add_instruction(migraphx::make_op("identity"), l0); mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{0}); EXPECT(result == migraphx::literal{0});
} }
...@@ -151,10 +151,10 @@ TEST_CASE(duplicate_args2) ...@@ -151,10 +151,10 @@ TEST_CASE(duplicate_args2)
auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3); auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3);
mm->add_instruction(migraphx::make_op("add"), sum1, l3); mm->add_instruction(migraphx::make_op("add"), sum1, l3);
mm->add_instruction(migraphx::make_op("identity"), l0); mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{0}); EXPECT(result == migraphx::literal{0});
} }
...@@ -169,10 +169,10 @@ TEST_CASE(duplicate_args3) ...@@ -169,10 +169,10 @@ TEST_CASE(duplicate_args3)
auto sum2 = mm->add_instruction(migraphx::make_op("add"), l0, sum1); auto sum2 = mm->add_instruction(migraphx::make_op("add"), l0, sum1);
mm->add_instruction(migraphx::make_op("add"), sum2, l3); mm->add_instruction(migraphx::make_op("add"), sum2, l3);
mm->add_instruction(migraphx::make_op("identity"), l0); mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{0}); EXPECT(result == migraphx::literal{0});
} }
......
...@@ -21,9 +21,9 @@ TEST_CASE(standard_op) ...@@ -21,9 +21,9 @@ TEST_CASE(standard_op)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_standard_op{}, c); mm->add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
TEST_CASE(standard_op_const) TEST_CASE(standard_op_const)
...@@ -36,7 +36,7 @@ TEST_CASE(standard_op_const) ...@@ -36,7 +36,7 @@ TEST_CASE(standard_op_const)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_standard_op{}, c); mm->add_instruction(pass_standard_op{}, c);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
} }
TEST_CASE(non_standard_op) TEST_CASE(non_standard_op)
...@@ -48,9 +48,9 @@ TEST_CASE(non_standard_op) ...@@ -48,9 +48,9 @@ TEST_CASE(non_standard_op)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_op{}, c); mm->add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
TEST_CASE(non_standard_op_const) TEST_CASE(non_standard_op_const)
...@@ -63,7 +63,7 @@ TEST_CASE(non_standard_op_const) ...@@ -63,7 +63,7 @@ TEST_CASE(non_standard_op_const)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_op{}, c); mm->add_instruction(pass_op{}, c);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(mm->begin(), mm->end()) == 2);
} }
TEST_CASE(transpose_gemm) TEST_CASE(transpose_gemm)
...@@ -76,9 +76,9 @@ TEST_CASE(transpose_gemm) ...@@ -76,9 +76,9 @@ TEST_CASE(transpose_gemm)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto ic = mm->add_instruction(migraphx::make_op("identity"), c); auto ic = mm->add_instruction(migraphx::make_op("identity"), c);
mm->add_instruction(migraphx::make_op("dot"), ic, l); mm->add_instruction(migraphx::make_op("dot"), ic, l);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
} }
TEST_CASE(transpose_standard_op) TEST_CASE(transpose_standard_op)
...@@ -91,9 +91,9 @@ TEST_CASE(transpose_standard_op) ...@@ -91,9 +91,9 @@ TEST_CASE(transpose_standard_op)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c); auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
TEST_CASE(transpose_standard_op_const) TEST_CASE(transpose_standard_op_const)
...@@ -107,7 +107,7 @@ TEST_CASE(transpose_standard_op_const) ...@@ -107,7 +107,7 @@ TEST_CASE(transpose_standard_op_const)
auto sn = mm->add_instruction(migraphx::make_op("sin"), c); auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); mm->add_instruction(pass_standard_op{}, sn);
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 3); EXPECT(std::distance(mm->begin(), mm->end()) == 3);
} }
TEST_CASE(no_packed_unary_op) TEST_CASE(no_packed_unary_op)
...@@ -121,9 +121,9 @@ TEST_CASE(no_packed_unary_op) ...@@ -121,9 +121,9 @@ TEST_CASE(no_packed_unary_op)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t); auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c); auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn); mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1); EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
} }
TEST_CASE(non_standard_return_input) TEST_CASE(non_standard_return_input)
...@@ -135,9 +135,9 @@ TEST_CASE(non_standard_return_input) ...@@ -135,9 +135,9 @@ TEST_CASE(non_standard_return_input)
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l); auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), tl); auto c = mm->add_instruction(migraphx::make_op("contiguous"), tl);
mm->add_return({c}); mm->add_return({c});
auto count = std::distance(p.begin(), p.end()); auto count = std::distance(mm->begin(), mm->end());
run_pass(p); run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(mm->begin(), mm->end()) == count);
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -24,7 +24,7 @@ TEST_CASE(simple_test) ...@@ -24,7 +24,7 @@ TEST_CASE(simple_test)
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two); auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity); mm->add_instruction(sum_op{}, one_identity, two_identity);
run_pass(p); run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -42,7 +42,7 @@ TEST_CASE(simple_test_end) ...@@ -42,7 +42,7 @@ TEST_CASE(simple_test_end)
auto ans = mm->add_instruction(sum_op{}, one, two); auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end_dependency) ...@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end_dependency)
mm->add_instruction(sum_op{}, ans, three); mm->add_instruction(sum_op{}, ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans); mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p); run_pass(p);
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) { EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity"; return ins.name() == "identity";
})); }));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
......
...@@ -74,8 +74,9 @@ TEST_CASE(rewrite_pad) ...@@ -74,8 +74,9 @@ TEST_CASE(rewrite_pad)
EXPECT(op1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1}); EXPECT(op1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(op2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1}); EXPECT(op2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(std::none_of( EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); return ins.name() == "pad";
}));
} }
TEST_CASE(rewrite_pad_im2col_asymmetric) TEST_CASE(rewrite_pad_im2col_asymmetric)
...@@ -103,8 +104,9 @@ TEST_CASE(rewrite_pad_im2col_asymmetric) ...@@ -103,8 +104,9 @@ TEST_CASE(rewrite_pad_im2col_asymmetric)
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0}); EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0});
run_pass(p); run_pass(p);
EXPECT(std::any_of( EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; })); return ins.name() == "pad";
}));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -71,7 +71,7 @@ struct reverse_pass ...@@ -71,7 +71,7 @@ struct reverse_pass
{ {
std::string name() const { return "reverse_pass"; } std::string name() const { return "reverse_pass"; }
void apply(migraphx::module& p) const { std::reverse(p.begin(), p.end()); } void apply(migraphx::module& m) const { std::reverse(m.begin(), m.end()); }
}; };
struct reverse_target struct reverse_target
...@@ -225,7 +225,7 @@ TEST_CASE(get_param1) ...@@ -225,7 +225,7 @@ TEST_CASE(get_param1)
mm->add_instruction(sum_op{}, x, y); mm->add_instruction(sum_op{}, x, y);
EXPECT(bool{p.get_parameter("x") == x}); EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y}); EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == p.end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
} }
TEST_CASE(get_param2) TEST_CASE(get_param2)
...@@ -235,7 +235,7 @@ TEST_CASE(get_param2) ...@@ -235,7 +235,7 @@ TEST_CASE(get_param2)
auto one = mm->add_literal(1); auto one = mm->add_literal(1);
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
EXPECT(bool{p.get_parameter("nonexistent") == p.end()}); EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
} }
TEST_CASE(get_param_shapes) TEST_CASE(get_param_shapes)
...@@ -260,7 +260,7 @@ TEST_CASE(replace_test) ...@@ -260,7 +260,7 @@ TEST_CASE(replace_test)
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
mm->replace_instruction(sum, minus_op{}, two, one); mm->replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
...@@ -276,7 +276,7 @@ TEST_CASE(replace_ins_test) ...@@ -276,7 +276,7 @@ TEST_CASE(replace_ins_test)
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(minus_op{}, two, one);
mm->replace_instruction(sum, minus); mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
...@@ -293,7 +293,7 @@ TEST_CASE(replace_ins_test2) ...@@ -293,7 +293,7 @@ TEST_CASE(replace_ins_test2)
auto minus = mm->add_instruction(minus_op{}, two, one); auto minus = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(pass_op{}, minus); mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum); mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{2}); EXPECT(result == migraphx::literal{2});
...@@ -308,7 +308,7 @@ TEST_CASE(replace_op_test) ...@@ -308,7 +308,7 @@ TEST_CASE(replace_op_test)
auto two = mm->add_literal(2); auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one); auto sum = mm->add_instruction(sum_op{}, two, one);
sum->replace(minus_op{}); sum->replace(minus_op{});
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1}); EXPECT(result == migraphx::literal{1});
...@@ -336,7 +336,7 @@ TEST_CASE(insert_replace_test) ...@@ -336,7 +336,7 @@ TEST_CASE(insert_replace_test)
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two); auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two); mm->replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{4}); EXPECT(result == migraphx::literal{4});
...@@ -352,7 +352,7 @@ TEST_CASE(remove_test1) ...@@ -352,7 +352,7 @@ TEST_CASE(remove_test1)
auto sum = mm->add_instruction(sum_op{}, one, two); auto sum = mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one); auto removed = mm->add_instruction(minus_op{}, sum, one);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -368,7 +368,7 @@ TEST_CASE(remove_test2) ...@@ -368,7 +368,7 @@ TEST_CASE(remove_test2)
auto removed = mm->add_instruction(minus_op{}, two, one); auto removed = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(sum_op{}, one, two); mm->add_instruction(sum_op{}, one, two);
mm->remove_instruction(removed); mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3}); EXPECT(result == migraphx::literal{3});
...@@ -509,16 +509,16 @@ TEST_CASE(debug_print_test) ...@@ -509,16 +509,16 @@ TEST_CASE(debug_print_test)
auto* mm2 = p2.get_main_module(); auto* mm2 = p2.get_main_module();
auto one2 = mm2->add_literal(1); auto one2 = mm2->add_literal(1);
auto program_out = migraphx::trim(capture_output([&] { p.debug_print(); })); auto program_out = migraphx::trim(capture_output([&] { mm->debug_print(); }));
auto ins_out = migraphx::trim(capture_output([&] { p.debug_print(one); })); auto ins_out = migraphx::trim(capture_output([&] { mm->debug_print(one); }));
auto inss_out = migraphx::trim(capture_output([&] { p.debug_print(onev); })); auto inss_out = migraphx::trim(capture_output([&] { mm->debug_print(onev); }));
auto end_out = migraphx::trim(capture_output([&] { p.debug_print(p.end()); })); auto end_out = migraphx::trim(capture_output([&] { mm->debug_print(mm->end()); }));
auto p2_ins_out = migraphx::trim(capture_output([&] { p.debug_print(one2); })); auto p2_ins_out = migraphx::trim(capture_output([&] { mm->debug_print(one2); }));
EXPECT(program_out == ins_out); EXPECT(program_out == ins_out);
EXPECT(inss_out == ins_out); EXPECT(inss_out == ins_out);
EXPECT(end_out == "End instruction"); EXPECT(end_out == "End instruction");
EXPECT(p2_ins_out == "Instruction not part of program"); EXPECT(p2_ins_out == "Instruction not part of module");
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -51,7 +51,7 @@ TEST_CASE(tanh_shape) ...@@ -51,7 +51,7 @@ TEST_CASE(tanh_shape)
EXPECT(p1 == p2); EXPECT(p1 == p2);
for(auto ins : iterator_for(p1)) for(auto ins : iterator_for(*p1.get_main_module()))
{ {
if(ins->name() == "hip::allocate") if(ins->name() == "hip::allocate")
{ {
......
...@@ -48,7 +48,7 @@ TEST_CASE(enable_fast_gelu) ...@@ -48,7 +48,7 @@ TEST_CASE(enable_fast_gelu)
{ {
migraphx::program p = create_gelu(); migraphx::program p = create_gelu();
p.compile(migraphx::gpu::target{}); p.compile(migraphx::gpu::target{});
CHECK(any_of(p, [&](auto&& i) { return i.name() == "gpu::gelu"; })); CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu"; }));
} }
TEST_CASE(disable_fast_gelu) TEST_CASE(disable_fast_gelu)
...@@ -57,7 +57,7 @@ TEST_CASE(disable_fast_gelu) ...@@ -57,7 +57,7 @@ TEST_CASE(disable_fast_gelu)
migraphx::compile_options options; migraphx::compile_options options;
options.fast_math = false; options.fast_math = false;
p.compile(migraphx::gpu::target{}, options); p.compile(migraphx::gpu::target{}, options);
CHECK(any_of(p, [&](auto&& i) { return i.name() == "gpu::gelu_new"; })); CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -14,7 +14,7 @@ void gpu_literal_test() ...@@ -14,7 +14,7 @@ void gpu_literal_test()
mm->add_literal(lit); mm->add_literal(lit);
p.compile(migraphx::gpu::target{}); p.compile(migraphx::gpu::target{});
auto scratch = p.get_parameter("scratch"); auto scratch = p.get_parameter("scratch");
if(scratch == p.end()) if(scratch == mm->end())
{ {
auto result = p.eval({}).back(); auto result = p.eval({}).back();
EXPECT(lit == migraphx::gpu::from_gpu(result)); EXPECT(lit == migraphx::gpu::from_gpu(result));
......
...@@ -56,7 +56,7 @@ TEST_CASE(match_name2) ...@@ -56,7 +56,7 @@ TEST_CASE(match_name2)
mm->add_instruction(pass_op{}, sum); mm->add_instruction(pass_op{}, sum);
auto m = match::name("min"); auto m = match::name("min");
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_name3) TEST_CASE(match_name3)
...@@ -98,7 +98,7 @@ TEST_CASE(match_arg2) ...@@ -98,7 +98,7 @@ TEST_CASE(match_arg2)
mm->add_instruction(pass_op{}, sum); mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape()); auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_arg3) TEST_CASE(match_arg3)
...@@ -140,7 +140,7 @@ TEST_CASE(match_arg5) ...@@ -140,7 +140,7 @@ TEST_CASE(match_arg5)
mm->add_instruction(pass_op{}, sum); mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape()); auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_arg6) TEST_CASE(match_arg6)
...@@ -257,7 +257,7 @@ TEST_CASE(match_args2) ...@@ -257,7 +257,7 @@ TEST_CASE(match_args2)
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")), auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_args3) TEST_CASE(match_args3)
...@@ -271,7 +271,7 @@ TEST_CASE(match_args3) ...@@ -271,7 +271,7 @@ TEST_CASE(match_args3)
mm->add_instruction(pass_op{}, sum); mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape()); auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_args4) TEST_CASE(match_args4)
...@@ -302,7 +302,7 @@ TEST_CASE(match_args5) ...@@ -302,7 +302,7 @@ TEST_CASE(match_args5)
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")), auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape()); match::standard_shape());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_args6) TEST_CASE(match_args6)
...@@ -380,7 +380,7 @@ TEST_CASE(match_either_args3) ...@@ -380,7 +380,7 @@ TEST_CASE(match_either_args3)
auto m = auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal"))); match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_either_args_any1) TEST_CASE(match_either_args_any1)
...@@ -495,7 +495,7 @@ TEST_CASE(match_all_of2) ...@@ -495,7 +495,7 @@ TEST_CASE(match_all_of2)
auto m = match::name("sum")( auto m = match::name("sum")(
match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal")))); match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_all_of3) TEST_CASE(match_all_of3)
...@@ -534,7 +534,7 @@ TEST_CASE(match_lazy_all_of) ...@@ -534,7 +534,7 @@ TEST_CASE(match_lazy_all_of)
mm->add_instruction(pass_op{}, one); mm->add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws()); auto m = match::all_of(match::none(), throws());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_lazy_none_of) TEST_CASE(match_lazy_none_of)
...@@ -546,7 +546,7 @@ TEST_CASE(match_lazy_none_of) ...@@ -546,7 +546,7 @@ TEST_CASE(match_lazy_none_of)
mm->add_instruction(pass_op{}, one); mm->add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws()); auto m = match::none_of(match::any(), throws());
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_any_of1) TEST_CASE(match_any_of1)
...@@ -576,7 +576,7 @@ TEST_CASE(match_any_of2) ...@@ -576,7 +576,7 @@ TEST_CASE(match_any_of2)
auto m = match::name("sum")( auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum")))); match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_any_of_lazy1) TEST_CASE(match_any_of_lazy1)
...@@ -707,7 +707,7 @@ TEST_CASE(match_none_of2) ...@@ -707,7 +707,7 @@ TEST_CASE(match_none_of2)
auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")), auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal")))); match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_output1) TEST_CASE(match_output1)
...@@ -737,7 +737,7 @@ TEST_CASE(match_output2) ...@@ -737,7 +737,7 @@ TEST_CASE(match_output2)
mm->add_instruction(pass_op{}, sum); mm->add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum"))); auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_skip_output1) TEST_CASE(match_skip_output1)
...@@ -818,7 +818,7 @@ TEST_CASE(match_skip_output5) ...@@ -818,7 +818,7 @@ TEST_CASE(match_skip_output5)
mm->add_instruction(pass_op{}, sum3); mm->add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum"))); auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_skip_output6) TEST_CASE(match_skip_output6)
...@@ -934,7 +934,7 @@ TEST_CASE(match_has_value4) ...@@ -934,7 +934,7 @@ TEST_CASE(match_has_value4)
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::has_value(3); auto m = match::has_value(3);
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_has_value5) TEST_CASE(match_has_value5)
...@@ -949,7 +949,7 @@ TEST_CASE(match_has_value5) ...@@ -949,7 +949,7 @@ TEST_CASE(match_has_value5)
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3))); auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_has_value6) TEST_CASE(match_has_value6)
...@@ -964,7 +964,7 @@ TEST_CASE(match_has_value6) ...@@ -964,7 +964,7 @@ TEST_CASE(match_has_value6)
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1))); auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_tree1) TEST_CASE(match_tree1)
...@@ -996,7 +996,7 @@ TEST_CASE(match_tree2) ...@@ -996,7 +996,7 @@ TEST_CASE(match_tree2)
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(1), match::has_value(3)); auto m = match::tree("sum", match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_tree3) TEST_CASE(match_tree3)
...@@ -1029,7 +1029,7 @@ TEST_CASE(match_tree4) ...@@ -1029,7 +1029,7 @@ TEST_CASE(match_tree4)
auto m = match::tree( auto m = match::tree(
"sum", match::has_value(1), match::has_value(2), match::has_value(3), match::has_value(4)); "sum", match::has_value(1), match::has_value(2), match::has_value(3), match::has_value(4));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_tree5) TEST_CASE(match_tree5)
...@@ -1045,7 +1045,7 @@ TEST_CASE(match_tree5) ...@@ -1045,7 +1045,7 @@ TEST_CASE(match_tree5)
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(3)); auto m = match::tree("sum", match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_tree6) TEST_CASE(match_tree6)
...@@ -1061,7 +1061,7 @@ TEST_CASE(match_tree6) ...@@ -1061,7 +1061,7 @@ TEST_CASE(match_tree6)
mm->add_instruction(pass_op{}, sum2); mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(1), match::has_value(3)); auto m = match::tree("sum", match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
TEST_CASE(match_unordered_tree1) TEST_CASE(match_unordered_tree1)
...@@ -1129,7 +1129,7 @@ TEST_CASE(match_unordered_tree4) ...@@ -1129,7 +1129,7 @@ TEST_CASE(match_unordered_tree4)
auto m = auto m =
match::unordered_tree("sum", match::has_value(4), match::has_value(2), match::has_value(1)); match::unordered_tree("sum", match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m); auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == mm->end()});
} }
struct match_find_sum struct match_find_sum
......
...@@ -42,7 +42,9 @@ migraphx::instruction_ref add_alloc(migraphx::program& p, const migraphx::shape& ...@@ -42,7 +42,9 @@ migraphx::instruction_ref add_alloc(migraphx::program& p, const migraphx::shape&
bool no_allocate(const migraphx::program& p) bool no_allocate(const migraphx::program& p)
{ {
return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; }); const auto* mm = p.get_main_module();
return std::none_of(
mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "allocate"; });
} }
TEST_CASE(test1) TEST_CASE(test1)
......
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <sstream>
#include "test.hpp"
#include <migraphx/make_op.hpp>
#include <basic_ops.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int64_type});
auto y = mm->add_parameter("y", {migraphx::shape::int64_type});
auto sum = mm->add_instruction(sum_op{}, x, y);
auto one = mm->add_literal(1);
mm->add_instruction(sum_op{}, sum, one);
return p;
}
TEST_CASE(module_ins_clear)
{
migraphx::program p1 = create_program();
migraphx::program p2;
p2 = p1;
EXPECT(p1 == p2);
}
TEST_CASE(module_print_graph)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
std::stringstream ss1;
mm1->print_graph(ss1, true);
std::stringstream ss2;
mm2->print_graph(ss2, true);
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_print_cpp)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
std::stringstream ss1;
mm1->print_cpp(ss1);
std::stringstream ss2;
mm2->print_cpp(ss2);
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_annotate)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
EXPECT(*mm1 == *mm2);
std::stringstream ss1;
mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
std::stringstream ss2;
mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
EXPECT(ss1.str() == ss2.str());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -40,18 +40,20 @@ struct normalize_test_op ...@@ -40,18 +40,20 @@ struct normalize_test_op
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, {migraphx::normalize_ops{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(*p.get_main_module(),
{migraphx::normalize_ops{}, migraphx::dead_code_elimination{}});
} }
migraphx::program create_gather(int64_t axis) migraphx::program create_gather(int64_t axis)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape si{migraphx::shape::int64_type, {2, 3}}; migraphx::shape si{migraphx::shape::int64_type, {2, 3}};
auto di = p.add_parameter("data", sd); auto di = mm->add_parameter("data", sd);
auto ii = p.add_parameter("ind", si); auto ii = mm->add_parameter("ind", si);
auto r = p.add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii);
p.add_return({r}); mm->add_return({r});
return p; return p;
} }
...@@ -78,10 +80,11 @@ TEST_CASE(gather_test_1) ...@@ -78,10 +80,11 @@ TEST_CASE(gather_test_1)
migraphx::program create_reduce_mean(const std::vector<int64_t>& axes) migraphx::program create_reduce_mean(const std::vector<int64_t>& axes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = p.add_parameter("data", s); auto si = mm->add_parameter("data", s);
auto r = p.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si); auto r = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si);
p.add_return({r}); mm->add_return({r});
return p; return p;
} }
...@@ -109,11 +112,12 @@ migraphx::program create_slice(const std::vector<int64_t>& axes, ...@@ -109,11 +112,12 @@ migraphx::program create_slice(const std::vector<int64_t>& axes,
const std::vector<int64_t>& ends) const std::vector<int64_t>& ends)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = p.add_parameter("data", s); auto si = mm->add_parameter("data", s);
auto r = p.add_instruction( auto r = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), si); migraphx::make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), si);
p.add_return({r}); mm->add_return({r});
return p; return p;
} }
...@@ -139,10 +143,11 @@ TEST_CASE(slice_test_1) ...@@ -139,10 +143,11 @@ TEST_CASE(slice_test_1)
migraphx::program create_test_op(const std::vector<int64_t>& axes) migraphx::program create_test_op(const std::vector<int64_t>& axes)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}}; migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}};
auto di = p.add_parameter("data", sd); auto di = mm->add_parameter("data", sd);
auto r = p.add_instruction(normalize_test_op{axes}, di); auto r = mm->add_instruction(normalize_test_op{axes}, di);
p.add_return({r}); mm->add_return({r});
return p; return p;
} }
......
...@@ -1031,17 +1031,18 @@ TEST_CASE(expand_test) ...@@ -1031,17 +1031,18 @@ TEST_CASE(expand_test)
migraphx::program create_external_data_prog() migraphx::program create_external_data_prog()
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s(migraphx::shape::float_type, {1, 1, 224, 224}); migraphx::shape s(migraphx::shape::float_type, {1, 1, 224, 224});
migraphx::shape s2(migraphx::shape::float_type, {10, 1, 11, 11}); migraphx::shape s2(migraphx::shape::float_type, {10, 1, 11, 11});
std::vector<float> weight_data(1210, 1); std::vector<float> weight_data(1210, 1);
std::vector<float> bias_data(10, 1); std::vector<float> bias_data(10, 1);
auto bias = p.add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data)); auto bias = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data));
auto weights = p.add_literal(migraphx::literal(s2, weight_data)); auto weights = mm->add_literal(migraphx::literal(s2, weight_data));
auto param = p.add_parameter("input", s); auto param = mm->add_parameter("input", s);
auto conv = p.add_instruction(migraphx::make_op("convolution"), param, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), param, weights);
auto bias_bcast = p.add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 10, 214, 214}}}), bias); migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 10, 214, 214}}}), bias);
p.add_instruction(migraphx::make_op("add"), conv, bias_bcast); mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast);
return p; return p;
} }
......
...@@ -28,6 +28,8 @@ TEST_CASE(program_equality) ...@@ -28,6 +28,8 @@ TEST_CASE(program_equality)
{ {
migraphx::program x = create_program(); migraphx::program x = create_program();
migraphx::program y = create_program(); migraphx::program y = create_program();
EXPECT(x.size() == 1);
EXPECT(x == y); EXPECT(x == y);
} }
...@@ -56,6 +58,40 @@ TEST_CASE(program_default_copy_construct) ...@@ -56,6 +58,40 @@ TEST_CASE(program_default_copy_construct)
EXPECT(x == y); EXPECT(x == y);
} }
TEST_CASE(program_print)
{
migraphx::program p = create_program();
auto* mm = p.get_main_module();
auto in1 = mm->end();
// print end instruction
p.debug_print(in1);
// print instruction not in the program
auto p2 = p;
auto* mm2 = p2.get_main_module();
auto in2 = mm2->begin();
p.debug_print(in2);
// print last instruction
auto in3 = std::prev(in1);
p.debug_print(in3);
}
TEST_CASE(program_annotate)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
std::stringstream ss1;
p1.annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
std::stringstream ss2;
p2.annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(program_copy) TEST_CASE(program_copy)
{ {
auto create_program_1 = [] { auto create_program_1 = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment