Unverified Commit c9b86f1c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module impl (#678)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* code backup

* code backup

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fixed a unit test error

* clang format

* fix unit test

* clang format

* code backup

* code change for more code coverage

* change program to module in various passes and matcher

* clang format

* modify the pass API

* code backup

* code backup

* clang format

* code backup

* clang format

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format

* code backup

* code backup

* clang format

* fix cppcheck errors

* clang format

* clang format

* fix build errors

* clang format

* modify gpu unit tests to using module

* clang format

* fix cppcheck error

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* fix review comments

* code backup

* clang format

* code backup

* clang format

* fix a bug related to a unit test

* clang format

* clang format

* fix a build error

* remove unnecessary code

* remove unnecessary files

* code backup

* clang format

* remove the compile function from the module class

* clang format

* clang format

* remove the context parameter from the from_value method of the module class

* code refinement

* clang format

* merge changes from develop branch

* clang format

* fix cppcheck error

* clang format

* fix a build error

* fixed a merge error

* fix cppcheck error

* fixed review comments

* clang format

* fix cppcheck error

* fix a cppcheck error

* fix cppcheck error

* fix build error caused by merge

* Add missing has_op function

* Formatting

* merge changes from develop branch

* fix a cppcheck error

* fixed some review comments

* clang format

* remove the begin/end function of the program class

* clang format

* refine code and fix cppcheck error

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* add unit tests for more code coverage

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix a build error in debug mode

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 1dd4e4d9
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
struct lowering
......
......@@ -8,8 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
......
......@@ -8,8 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
struct operation;
namespace gpu {
......
......@@ -8,8 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
......
......@@ -5,8 +5,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
......
......@@ -129,8 +129,8 @@ struct program_model
for(auto&& race : races)
{
std::cout << "Race:\n";
p.debug_print(race.ins);
p.debug_print(race.before);
mm->debug_print(race.ins);
mm->debug_print(race.before);
}
}
};
......
......@@ -17,9 +17,9 @@ TEST_CASE(simple_test)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -33,9 +33,9 @@ TEST_CASE(simple_test_nop)
auto two = mm->add_literal(2);
mm->add_instruction(nop{});
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -51,7 +51,7 @@ TEST_CASE(simple_test_nop2)
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(nop{});
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{});
EXPECT(result != migraphx::literal{4});
......@@ -65,9 +65,9 @@ TEST_CASE(duplicate_test1)
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -82,9 +82,9 @@ TEST_CASE(duplicate_test2)
mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(minus_op{}, one, two);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 2));
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -101,9 +101,9 @@ TEST_CASE(depth_test)
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(minus_op{}, x1, x2);
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 4));
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 4));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
......@@ -117,9 +117,9 @@ TEST_CASE(undefined_test)
auto two = mm->add_literal(2);
auto undef = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
EXPECT(not mm->has_instruction(undef));
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......@@ -134,10 +134,10 @@ TEST_CASE(duplicate_args1)
auto l3 = mm->add_literal(3);
mm->add_instruction(migraphx::make_op("add"), l3, l3);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{0});
}
......@@ -151,10 +151,10 @@ TEST_CASE(duplicate_args2)
auto sum1 = mm->add_instruction(migraphx::make_op("add"), l0, l3);
mm->add_instruction(migraphx::make_op("add"), sum1, l3);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{0});
}
......@@ -169,10 +169,10 @@ TEST_CASE(duplicate_args3)
auto sum2 = mm->add_instruction(migraphx::make_op("add"), l0, sum1);
mm->add_instruction(migraphx::make_op("add"), sum2, l3);
mm->add_instruction(migraphx::make_op("identity"), l0);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{0});
}
......
......@@ -21,9 +21,9 @@ TEST_CASE(standard_op)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_standard_op{}, c);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
}
TEST_CASE(standard_op_const)
......@@ -36,7 +36,7 @@ TEST_CASE(standard_op_const)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_standard_op{}, c);
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
}
TEST_CASE(non_standard_op)
......@@ -48,9 +48,9 @@ TEST_CASE(non_standard_op)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_op{}, c);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
}
TEST_CASE(non_standard_op_const)
......@@ -63,7 +63,7 @@ TEST_CASE(non_standard_op_const)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_instruction(pass_op{}, c);
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
}
TEST_CASE(transpose_gemm)
......@@ -76,9 +76,9 @@ TEST_CASE(transpose_gemm)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto ic = mm->add_instruction(migraphx::make_op("identity"), c);
mm->add_instruction(migraphx::make_op("dot"), ic, l);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1));
}
TEST_CASE(transpose_standard_op)
......@@ -91,9 +91,9 @@ TEST_CASE(transpose_standard_op)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
}
TEST_CASE(transpose_standard_op_const)
......@@ -107,7 +107,7 @@ TEST_CASE(transpose_standard_op_const)
auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn);
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == 3);
EXPECT(std::distance(mm->begin(), mm->end()) == 3);
}
TEST_CASE(no_packed_unary_op)
......@@ -121,9 +121,9 @@ TEST_CASE(no_packed_unary_op)
auto c = mm->add_instruction(migraphx::make_op("contiguous"), t);
auto sn = mm->add_instruction(migraphx::make_op("sin"), c);
mm->add_instruction(pass_standard_op{}, sn);
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == count - 1);
}
TEST_CASE(non_standard_return_input)
......@@ -135,9 +135,9 @@ TEST_CASE(non_standard_return_input)
auto tl = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), tl);
mm->add_return({c});
auto count = std::distance(p.begin(), p.end());
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == count);
EXPECT(std::distance(mm->begin(), mm->end()) == count);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -24,7 +24,7 @@ TEST_CASE(simple_test)
auto two_identity = mm->add_instruction(migraphx::make_op("identity"), two);
mm->add_instruction(sum_op{}, one_identity, two_identity);
run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({}).back();
......@@ -42,7 +42,7 @@ TEST_CASE(simple_test_end)
auto ans = mm->add_instruction(sum_op{}, one, two);
mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p);
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({}).back();
......@@ -62,7 +62,7 @@ TEST_CASE(simple_test_end_dependency)
mm->add_instruction(sum_op{}, ans, three);
mm->add_instruction(migraphx::make_op("identity"), ans);
run_pass(p);
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({}).back();
......
......@@ -74,8 +74,9 @@ TEST_CASE(rewrite_pad)
EXPECT(op1["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(op2["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{1, 1});
EXPECT(std::none_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
EXPECT(std::none_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "pad";
}));
}
TEST_CASE(rewrite_pad_im2col_asymmetric)
......@@ -103,8 +104,9 @@ TEST_CASE(rewrite_pad_im2col_asymmetric)
EXPECT(op0["padding"].to_vector<std::size_t>() == std::vector<std::size_t>{0, 0});
run_pass(p);
EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) {
return ins.name() == "pad";
}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -71,7 +71,7 @@ struct reverse_pass
{
std::string name() const { return "reverse_pass"; }
void apply(migraphx::module& p) const { std::reverse(p.begin(), p.end()); }
void apply(migraphx::module& m) const { std::reverse(m.begin(), m.end()); }
};
struct reverse_target
......@@ -225,7 +225,7 @@ TEST_CASE(get_param1)
mm->add_instruction(sum_op{}, x, y);
EXPECT(bool{p.get_parameter("x") == x});
EXPECT(bool{p.get_parameter("y") == y});
EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
}
TEST_CASE(get_param2)
......@@ -235,7 +235,7 @@ TEST_CASE(get_param2)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
EXPECT(bool{p.get_parameter("nonexistent") == mm->end()});
}
TEST_CASE(get_param_shapes)
......@@ -260,7 +260,7 @@ TEST_CASE(replace_test)
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, one, two);
mm->replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1});
......@@ -276,7 +276,7 @@ TEST_CASE(replace_ins_test)
auto sum = mm->add_instruction(sum_op{}, one, two);
auto minus = mm->add_instruction(minus_op{}, two, one);
mm->replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1});
......@@ -293,7 +293,7 @@ TEST_CASE(replace_ins_test2)
auto minus = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(pass_op{}, minus);
mm->replace_instruction(two, sum);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{2});
......@@ -308,7 +308,7 @@ TEST_CASE(replace_op_test)
auto two = mm->add_literal(2);
auto sum = mm->add_instruction(sum_op{}, two, one);
sum->replace(minus_op{});
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{1});
......@@ -336,7 +336,7 @@ TEST_CASE(insert_replace_test)
auto sum0 = mm->insert_instruction(sum1, sum_op{}, two, two);
mm->replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{4});
......@@ -352,7 +352,7 @@ TEST_CASE(remove_test1)
auto sum = mm->add_instruction(sum_op{}, one, two);
auto removed = mm->add_instruction(minus_op{}, sum, one);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......@@ -368,7 +368,7 @@ TEST_CASE(remove_test2)
auto removed = mm->add_instruction(minus_op{}, two, one);
mm->add_instruction(sum_op{}, one, two);
mm->remove_instruction(removed);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{p.validate() == mm->end()});
auto result = p.eval({}).back();
EXPECT(result == migraphx::literal{3});
......@@ -509,16 +509,16 @@ TEST_CASE(debug_print_test)
auto* mm2 = p2.get_main_module();
auto one2 = mm2->add_literal(1);
auto program_out = migraphx::trim(capture_output([&] { p.debug_print(); }));
auto ins_out = migraphx::trim(capture_output([&] { p.debug_print(one); }));
auto inss_out = migraphx::trim(capture_output([&] { p.debug_print(onev); }));
auto end_out = migraphx::trim(capture_output([&] { p.debug_print(p.end()); }));
auto p2_ins_out = migraphx::trim(capture_output([&] { p.debug_print(one2); }));
auto program_out = migraphx::trim(capture_output([&] { mm->debug_print(); }));
auto ins_out = migraphx::trim(capture_output([&] { mm->debug_print(one); }));
auto inss_out = migraphx::trim(capture_output([&] { mm->debug_print(onev); }));
auto end_out = migraphx::trim(capture_output([&] { mm->debug_print(mm->end()); }));
auto p2_ins_out = migraphx::trim(capture_output([&] { mm->debug_print(one2); }));
EXPECT(program_out == ins_out);
EXPECT(inss_out == ins_out);
EXPECT(end_out == "End instruction");
EXPECT(p2_ins_out == "Instruction not part of program");
EXPECT(p2_ins_out == "Instruction not part of module");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -51,7 +51,7 @@ TEST_CASE(tanh_shape)
EXPECT(p1 == p2);
for(auto ins : iterator_for(p1))
for(auto ins : iterator_for(*p1.get_main_module()))
{
if(ins->name() == "hip::allocate")
{
......
......@@ -48,7 +48,7 @@ TEST_CASE(enable_fast_gelu)
{
migraphx::program p = create_gelu();
p.compile(migraphx::gpu::target{});
CHECK(any_of(p, [&](auto&& i) { return i.name() == "gpu::gelu"; }));
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu"; }));
}
TEST_CASE(disable_fast_gelu)
......@@ -57,7 +57,7 @@ TEST_CASE(disable_fast_gelu)
migraphx::compile_options options;
options.fast_math = false;
p.compile(migraphx::gpu::target{}, options);
CHECK(any_of(p, [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
CHECK(any_of(*p.get_main_module(), [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -14,7 +14,7 @@ void gpu_literal_test()
mm->add_literal(lit);
p.compile(migraphx::gpu::target{});
auto scratch = p.get_parameter("scratch");
if(scratch == p.end())
if(scratch == mm->end())
{
auto result = p.eval({}).back();
EXPECT(lit == migraphx::gpu::from_gpu(result));
......
......@@ -56,7 +56,7 @@ TEST_CASE(match_name2)
mm->add_instruction(pass_op{}, sum);
auto m = match::name("min");
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_name3)
......@@ -98,7 +98,7 @@ TEST_CASE(match_arg2)
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::arg(0)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_arg3)
......@@ -140,7 +140,7 @@ TEST_CASE(match_arg5)
mm->add_instruction(pass_op{}, sum);
auto m = match::name("pass")(match::arg(1)(match::name("sum")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_arg6)
......@@ -257,7 +257,7 @@ TEST_CASE(match_args2)
auto m = match::name("sum")(match::args(match::name("@literal"), match::name("sum")),
match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_args3)
......@@ -271,7 +271,7 @@ TEST_CASE(match_args3)
mm->add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::args(match::name("@literal")), match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_args4)
......@@ -302,7 +302,7 @@ TEST_CASE(match_args5)
auto m = match::name("sum")(match::args(match::name("sum"), match::name("@literal")),
match::standard_shape());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_args6)
......@@ -380,7 +380,7 @@ TEST_CASE(match_either_args3)
auto m =
match::name("sum")(match::either_arg(0, 1)(match::name("pass"), match::name("@literal")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_either_args_any1)
......@@ -495,7 +495,7 @@ TEST_CASE(match_all_of2)
auto m = match::name("sum")(
match::all_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_all_of3)
......@@ -534,7 +534,7 @@ TEST_CASE(match_lazy_all_of)
mm->add_instruction(pass_op{}, one);
auto m = match::all_of(match::none(), throws());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_lazy_none_of)
......@@ -546,7 +546,7 @@ TEST_CASE(match_lazy_none_of)
mm->add_instruction(pass_op{}, one);
auto m = match::none_of(match::any(), throws());
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_any_of1)
......@@ -576,7 +576,7 @@ TEST_CASE(match_any_of2)
auto m = match::name("sum")(
match::any_of(match::arg(0)(match::name("sum")), match::arg(1)(match::name("sum"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_any_of_lazy1)
......@@ -707,7 +707,7 @@ TEST_CASE(match_none_of2)
auto m = match::name("sum")(match::none_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_output1)
......@@ -737,7 +737,7 @@ TEST_CASE(match_output2)
mm->add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_skip_output1)
......@@ -818,7 +818,7 @@ TEST_CASE(match_skip_output5)
mm->add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_skip_output6)
......@@ -934,7 +934,7 @@ TEST_CASE(match_has_value4)
mm->add_instruction(pass_op{}, sum2);
auto m = match::has_value(3);
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_has_value5)
......@@ -949,7 +949,7 @@ TEST_CASE(match_has_value5)
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(1), match::has_value(3)));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_has_value6)
......@@ -964,7 +964,7 @@ TEST_CASE(match_has_value6)
mm->add_instruction(pass_op{}, sum2);
auto m = match::name("sum")(match::args(match::has_value(2), match::has_value(1)));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_tree1)
......@@ -996,7 +996,7 @@ TEST_CASE(match_tree2)
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_tree3)
......@@ -1029,7 +1029,7 @@ TEST_CASE(match_tree4)
auto m = match::tree(
"sum", match::has_value(1), match::has_value(2), match::has_value(3), match::has_value(4));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_tree5)
......@@ -1045,7 +1045,7 @@ TEST_CASE(match_tree5)
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(2), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_tree6)
......@@ -1061,7 +1061,7 @@ TEST_CASE(match_tree6)
mm->add_instruction(pass_op{}, sum2);
auto m = match::tree("sum", match::has_value(1), match::has_value(3));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
TEST_CASE(match_unordered_tree1)
......@@ -1129,7 +1129,7 @@ TEST_CASE(match_unordered_tree4)
auto m =
match::unordered_tree("sum", match::has_value(4), match::has_value(2), match::has_value(1));
auto r = find_match(*mm, m);
EXPECT(bool{r.result == p.end()});
EXPECT(bool{r.result == mm->end()});
}
struct match_find_sum
......
......@@ -42,7 +42,9 @@ migraphx::instruction_ref add_alloc(migraphx::program& p, const migraphx::shape&
bool no_allocate(const migraphx::program& p)
{
return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; });
const auto* mm = p.get_main_module();
return std::none_of(
mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "allocate"; });
}
TEST_CASE(test1)
......
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <sstream>
#include "test.hpp"
#include <migraphx/make_op.hpp>
#include <basic_ops.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int64_type});
auto y = mm->add_parameter("y", {migraphx::shape::int64_type});
auto sum = mm->add_instruction(sum_op{}, x, y);
auto one = mm->add_literal(1);
mm->add_instruction(sum_op{}, sum, one);
return p;
}
TEST_CASE(module_ins_clear)
{
migraphx::program p1 = create_program();
migraphx::program p2;
p2 = p1;
EXPECT(p1 == p2);
}
TEST_CASE(module_print_graph)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
std::stringstream ss1;
mm1->print_graph(ss1, true);
std::stringstream ss2;
mm2->print_graph(ss2, true);
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_print_cpp)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
std::stringstream ss1;
mm1->print_cpp(ss1);
std::stringstream ss2;
mm2->print_cpp(ss2);
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(module_annotate)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
auto* mm1 = p1.get_main_module();
auto* mm2 = p2.get_main_module();
EXPECT(*mm1 == *mm2);
std::stringstream ss1;
mm1->annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
std::stringstream ss2;
mm2->annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
EXPECT(ss1.str() == ss2.str());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -40,18 +40,20 @@ struct normalize_test_op
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::normalize_ops{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(*p.get_main_module(),
{migraphx::normalize_ops{}, migraphx::dead_code_elimination{}});
}
migraphx::program create_gather(int64_t axis)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}};
migraphx::shape si{migraphx::shape::int64_type, {2, 3}};
auto di = p.add_parameter("data", sd);
auto ii = p.add_parameter("ind", si);
auto r = p.add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii);
p.add_return({r});
auto di = mm->add_parameter("data", sd);
auto ii = mm->add_parameter("ind", si);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), di, ii);
mm->add_return({r});
return p;
}
......@@ -78,10 +80,11 @@ TEST_CASE(gather_test_1)
migraphx::program create_reduce_mean(const std::vector<int64_t>& axes)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = p.add_parameter("data", s);
auto r = p.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si);
p.add_return({r});
auto si = mm->add_parameter("data", s);
auto r = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), si);
mm->add_return({r});
return p;
}
......@@ -109,11 +112,12 @@ migraphx::program create_slice(const std::vector<int64_t>& axes,
const std::vector<int64_t>& ends)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
auto si = p.add_parameter("data", s);
auto r = p.add_instruction(
auto si = mm->add_parameter("data", s);
auto r = mm->add_instruction(
migraphx::make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), si);
p.add_return({r});
mm->add_return({r});
return p;
}
......@@ -139,10 +143,11 @@ TEST_CASE(slice_test_1)
migraphx::program create_test_op(const std::vector<int64_t>& axes)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {2, 3, 4}};
auto di = p.add_parameter("data", sd);
auto r = p.add_instruction(normalize_test_op{axes}, di);
p.add_return({r});
auto di = mm->add_parameter("data", sd);
auto r = mm->add_instruction(normalize_test_op{axes}, di);
mm->add_return({r});
return p;
}
......
......@@ -1031,17 +1031,18 @@ TEST_CASE(expand_test)
migraphx::program create_external_data_prog()
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s(migraphx::shape::float_type, {1, 1, 224, 224});
migraphx::shape s2(migraphx::shape::float_type, {10, 1, 11, 11});
std::vector<float> weight_data(1210, 1);
std::vector<float> bias_data(10, 1);
auto bias = p.add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data));
auto weights = p.add_literal(migraphx::literal(s2, weight_data));
auto param = p.add_parameter("input", s);
auto conv = p.add_instruction(migraphx::make_op("convolution"), param, weights);
auto bias_bcast = p.add_instruction(
auto bias = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {10}}, bias_data));
auto weights = mm->add_literal(migraphx::literal(s2, weight_data));
auto param = mm->add_parameter("input", s);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), param, weights);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", {1, 10, 214, 214}}}), bias);
p.add_instruction(migraphx::make_op("add"), conv, bias_bcast);
mm->add_instruction(migraphx::make_op("add"), conv, bias_bcast);
return p;
}
......
......@@ -28,6 +28,8 @@ TEST_CASE(program_equality)
{
migraphx::program x = create_program();
migraphx::program y = create_program();
EXPECT(x.size() == 1);
EXPECT(x == y);
}
......@@ -56,6 +58,40 @@ TEST_CASE(program_default_copy_construct)
EXPECT(x == y);
}
TEST_CASE(program_print)
{
migraphx::program p = create_program();
auto* mm = p.get_main_module();
auto in1 = mm->end();
// print end instruction
p.debug_print(in1);
// print instruction not in the program
auto p2 = p;
auto* mm2 = p2.get_main_module();
auto in2 = mm2->begin();
p.debug_print(in2);
// print last instruction
auto in3 = std::prev(in1);
p.debug_print(in3);
}
TEST_CASE(program_annotate)
{
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
std::stringstream ss1;
p1.annotate(ss1, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
std::stringstream ss2;
p2.annotate(ss2, [](auto ins) { std::cout << ins->name() << "_1" << std::endl; });
EXPECT(ss1.str() == ss2.str());
}
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment