Unverified Commit c9b86f1c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module impl (#678)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* code backup

* code backup

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fixed a unit test error

* clang format

* fix unit test

* clang format

* code backup

* code change for more code coverage

* change program to module in various passes and matcher

* clang format

* modify the pass API

* code backup

* code backup

* clang format

* code backup

* clang format

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format

* code backup

* code backup

* clang format

* fix cppcheck errors

* clang format

* clang format

* fix build errors

* clang format

* modify gpu unit tests to using module

* clang format

* fix cppcheck error

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* fix review comments

* code backup

* clang format

* code backup

* clang format

* fix a bug related to a unit test

* clang format

* clang format

* fix a build error

* remove unnecessary code

* remove unnecessary files

* code backup

* clang format

* remove the compile function from the module class

* clang format

* clang format

* remove the context parameter from the from_value method of the module class

* code refinement

* clang format

* merge changes from develop branch

* clang format

* fix cppcheck error

* clang format

* fix a build error

* fixed a merge error

* fix cppcheck error

* fixed review comments

* clang format

* fix cppcheck error

* fix a cppcheck error

* fix cppcheck error

* fix build error caused by merge

* Add missing has_op function

* Formatting

* merge changes from develop branch

* fix a cppcheck error

* fixed some review comments

* clang format

* remove the begin/end function of the program class

* clang format

* refine code and fix cppcheck error

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* add unit tests for more code coverage

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix a build error in debug mode

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 1dd4e4d9
......@@ -108,8 +108,8 @@ TEST_CASE(non_literal)
migraphx::rewrite_batchnorm opt;
opt.apply(*p2.get_main_module());
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
EXPECT(any_of(*p1.get_main_module(), &is_batch_norm));
EXPECT(none_of(*p2.get_main_module(), &is_batch_norm));
}
TEST_CASE(as_literal)
......@@ -137,8 +137,8 @@ TEST_CASE(as_literal)
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(*p2.get_main_module());
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
EXPECT(any_of(*p1.get_main_module(), &is_batch_norm));
EXPECT(none_of(*p2.get_main_module(), &is_batch_norm));
p1.compile(migraphx::ref::target{});
p2.compile(migraphx::ref::target{});
......@@ -176,8 +176,8 @@ TEST_CASE(as_literal_1d)
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(*p2.get_main_module());
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
EXPECT(any_of(*p1.get_main_module(), &is_batch_norm));
EXPECT(none_of(*p2.get_main_module(), &is_batch_norm));
p1.compile(migraphx::ref::target{});
p2.compile(migraphx::ref::target{});
......@@ -216,8 +216,8 @@ TEST_CASE(as_literal_3d)
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(*p2.get_main_module());
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
EXPECT(any_of(*p1.get_main_module(), &is_batch_norm));
EXPECT(none_of(*p2.get_main_module(), &is_batch_norm));
p1.compile(migraphx::ref::target{});
p2.compile(migraphx::ref::target{});
......@@ -252,8 +252,8 @@ TEST_CASE(literal_reshape)
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(*p2.get_main_module());
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
EXPECT(any_of(*p1.get_main_module(), &is_batch_norm));
EXPECT(none_of(*p2.get_main_module(), &is_batch_norm));
p1.compile(migraphx::ref::target{});
p2.compile(migraphx::ref::target{});
......@@ -303,8 +303,8 @@ TEST_CASE(literal_reshape_per_actv)
migraphx::program p2 = create_program();
migraphx::rewrite_batchnorm opt;
opt.apply(*p2.get_main_module());
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
EXPECT(any_of(*p1.get_main_module(), &is_batch_norm));
EXPECT(none_of(*p2.get_main_module(), &is_batch_norm));
p1.compile(migraphx::ref::target{});
p2.compile(migraphx::ref::target{});
......
......@@ -144,7 +144,8 @@ struct schedule_model_test
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
for(auto ins : migraphx::iterator_for(p))
auto* mm = p.get_main_module();
for(auto ins : migraphx::iterator_for(*mm))
{
if(ins->name() != "identity")
continue;
......
......@@ -222,8 +222,8 @@ TEST_CASE(simplify_mul_conv1)
mm->add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(p);
auto new_conv =
std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
auto new_conv = std::find_if(
mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "convolution"; });
EXPECT(new_conv->outputs().front()->name() != "mul");
}
......@@ -399,8 +399,9 @@ TEST_CASE(simplify_add_conv1)
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 1);
}
TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
......@@ -422,8 +423,9 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
// No fusion
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 2);
}
TEST_CASE(simplify_add_conv_1x1_diff_strides1)
......@@ -444,8 +446,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 1);
}
TEST_CASE(simplify_add_conv_1x1_diff_strides2)
......@@ -466,8 +469,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 1);
}
TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
......@@ -488,8 +492,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 1);
}
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
......@@ -511,8 +516,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
// No fusion
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 2);
}
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
......@@ -534,8 +540,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
// No fusion
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto&& ins) {
return ins.name() == "convolution";
}) == 2);
}
TEST_CASE(simplify_concat_add_relu)
......
......@@ -32,7 +32,7 @@ TEST_CASE(double_contig)
run_pass(p);
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 4);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
auto result = p.eval({}).back();
EXPECT(result != get_2x2());
}
......@@ -51,7 +51,7 @@ TEST_CASE(double_transpose)
run_pass(p);
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result == get_2x2());
}
......@@ -72,7 +72,7 @@ TEST_CASE(double_transpose_contig)
run_pass(p);
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result == get_2x2());
}
......@@ -90,7 +90,7 @@ TEST_CASE(single_transpose)
run_pass(p);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 3);
EXPECT(std::distance(mm->begin(), mm->end()) == 3);
auto result = p.eval({}).back();
EXPECT(result != get_2x2());
}
......@@ -109,7 +109,7 @@ TEST_CASE(double_transpose_sin_pass)
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
// EXPECT(std::distance(mm->begin(), mm->end()) == 1);
auto result = p.eval({}).back();
EXPECT(result == get_2x2());
}
......@@ -126,7 +126,7 @@ TEST_CASE(single_transpose_sin_pass)
run_pass(p);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
EXPECT(std::distance(mm->begin(), mm->end()) == 2);
auto result = p.eval({}).back();
EXPECT(result != get_2x2());
}
......@@ -144,10 +144,10 @@ TEST_CASE(reshape_transpose)
auto r2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 112, 56, 56}}}), ct);
mm->add_return({r2});
EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
EXPECT(std::distance(mm->begin(), mm->end()) == n);
}
TEST_CASE(transpose_contiguous)
......@@ -161,10 +161,10 @@ TEST_CASE(transpose_contiguous)
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), t);
mm->add_return({c1});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n);
EXPECT(std::distance(mm->begin(), mm->end()) == n);
}
TEST_CASE(transpose_double_contiguous)
......@@ -179,10 +179,10 @@ TEST_CASE(transpose_double_contiguous)
auto c2 = mm->add_instruction(migraphx::make_op("contiguous"), c1);
mm->add_return({c2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
EXPECT(mm->has_instruction(t));
}
......@@ -197,10 +197,10 @@ TEST_CASE(transpose_partial1)
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 2, 0}}}), t1);
mm->add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}
TEST_CASE(transpose_partial2)
......@@ -215,10 +215,10 @@ TEST_CASE(transpose_partial2)
auto t3 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t2);
mm->add_return({t3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
}
TEST_CASE(transpose_partial3)
......@@ -234,10 +234,10 @@ TEST_CASE(transpose_partial3)
auto t4 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0, 2}}}), t3);
mm->add_return({t4});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 3);
}
TEST_CASE(nop_transpose1)
......@@ -250,10 +250,10 @@ TEST_CASE(nop_transpose1)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), x);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}
TEST_CASE(nop_transpose2)
......@@ -269,10 +269,10 @@ TEST_CASE(nop_transpose2)
auto t4 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 2}}}), t3);
mm->add_instruction(pass_op{}, t4);
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 4);
}
TEST_CASE(nop_transpose3)
......@@ -288,10 +288,10 @@ TEST_CASE(nop_transpose3)
auto t2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t1);
mm->add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}
TEST_CASE(nop_convert)
......@@ -307,10 +307,10 @@ TEST_CASE(nop_convert)
x);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}
TEST_CASE(concat_transpose1)
......@@ -327,13 +327,13 @@ TEST_CASE(concat_transpose1)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 3);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
std::find_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != mm->end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
......@@ -351,13 +351,13 @@ TEST_CASE(concat_transpose2)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
std::find_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != mm->end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
......@@ -375,13 +375,13 @@ TEST_CASE(concat_transpose3)
auto t = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 2, 3, 1}}}), concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
std::find_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != mm->end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
......@@ -419,11 +419,12 @@ TEST_CASE(nested_concat)
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2);
mm->add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; }) ==
1);
}
TEST_CASE(nested_concat_partial)
......@@ -442,11 +443,12 @@ TEST_CASE(nested_concat_partial)
mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), concat1, concat2, l);
mm->add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 2);
EXPECT(std::count_if(mm->begin(), mm->end(), [](auto ins) { return ins.name() == "concat"; }) ==
1);
}
TEST_CASE(multibroadcast_simplify)
......@@ -459,9 +461,9 @@ TEST_CASE(multibroadcast_simplify)
auto x = mm->add_parameter("x", s);
auto y = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", s_lens}}), x);
mm->add_instruction(migraphx::make_op("mul"), y, y);
auto n = std::distance(p.begin(), p.end());
auto n = std::distance(mm->begin(), mm->end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(std::distance(mm->begin(), mm->end()) == n - 1);
}
TEST_CASE(double_slice1)
......
......@@ -52,9 +52,10 @@ TEST_CASE(add_test)
TEST_CASE(addv2_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::make_op("add"), l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::make_op("add"), l0, l1);
auto prog = optimize_tf("addv2_test.pb", false);
EXPECT(p == prog);
......@@ -175,18 +176,19 @@ TEST_CASE(batchnormv3_test)
float momentum = 0.9f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::op::batch_norm_inference op{
epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
migraphx::shape s0{migraphx::shape::float_type, {32}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
std::vector<float> const_vals(32);
std::fill(const_vals.begin(), const_vals.end(), 1.0f);
auto l2 = p.add_parameter("2", s0);
auto l3 = p.add_parameter("3", s0);
auto l4 = p.add_parameter("4", s0);
auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
p.add_instruction(op, l0, l1, l2, l3, l4);
auto l2 = mm->add_parameter("2", s0);
auto l3 = mm->add_parameter("3", s0);
auto l4 = mm->add_parameter("4", s0);
auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
mm->add_instruction(op, l0, l1, l2, l3, l4);
auto prog = optimize_tf("batchnormv3_test.pb", true);
EXPECT(p == prog);
......
......@@ -11,7 +11,7 @@ TEST_CASE(simple_test)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
EXPECT(bool{p.validate() == p.end()});
EXPECT(bool{mm->validate() == mm->end()});
auto result = p.eval({});
EXPECT(result.back() == migraphx::literal{3});
EXPECT(result.back() != migraphx::literal{4});
......@@ -24,7 +24,7 @@ TEST_CASE(out_of_order)
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
mm->move_instruction(two, p.end());
mm->move_instruction(two, mm->end());
EXPECT(bool{p.validate() == ins});
}
......@@ -52,7 +52,7 @@ TEST_CASE(invalid_args)
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
access_ins_arguments(*ins).clear();
EXPECT(bool{p.validate() == p.begin()});
EXPECT(bool{mm->validate() == mm->begin()});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -9,12 +9,13 @@ struct gemm_literal : verify_program<gemm_literal>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}};
auto a = p.add_literal(migraphx::generate_literal(a_shape));
auto b = p.add_parameter("b", b_shape);
p.add_instruction(migraphx::op::dot{}, a, b);
auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b);
return p;
}
......
......@@ -15,9 +15,9 @@ inline void check_gpu_streams(const migraphx::program& p)
{
std::cout << "FAILED: " << std::endl;
std::cout << "Race condition detected for: ";
p.debug_print(race.ins);
mm->debug_print(race.ins);
std::cout << "Should happen after: ";
p.debug_print(race.before);
mm->debug_print(race.before);
}
#else
(void)p;
......
......@@ -13,7 +13,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
#ifdef DOXYGEN
......
......@@ -15,8 +15,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
struct operation;
#ifdef DOXYGEN
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment