Unverified Commit 1b692d0f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Change api to multiple prog outputs (only API change) (#433)



* Add initial api

* Formatting

* Add more api

* Formatting

* Add auto api generation

* Formatting

* Fix some compilation errors

* Change handle struct

* Formatting

* Fix reamining compilation errors

* Formatting

* Simplify using ctype

* Formatting

* Initial c++ generation

* Formatting

* Add C++header

* Formatting

* Add test

* Formatting

* Add initial tests

* Formatting

* Try to fix formatting

* Cleanup formatting

* Formatting

* Fix constructors on the same line

* Fix tests

* Formatting

* Fix tidy issues

* Fix tidy issues

* Fix naming issue

* Add onnx API to parse buffer

* Formatting

* Add arguments api

* Formatting

* Fix verify parameters

* Fix cppcheck issues

* Formatting

* Add method to get output shapes and bytes

* Formatting

* Try formatting

* Formatting

* Improve the test coverage

* Formatting

* Add print method

* Formatting

* Fix cppcheck issue

* Fix package dependency

* change migraphx api to support multiple program outputs

* clang format

* change api implementation

* clang format

* fix a build error

* change api for correct automatic generation

* clang format

* Add nolint

* Try fix formatting

* Formatting

* formatting

* formatting

* Fix formatting

* code cleanup

* clang format

* fix cppcheck error

* fix review comments

* clang format
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarkahmed10 <15948690+kahmed10@users.noreply.github.com>
parent ba07b221
...@@ -14,7 +14,7 @@ TEST_CASE(instance_norm_test) ...@@ -14,7 +14,7 @@ TEST_CASE(instance_norm_test)
migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx"); migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({}).back();
std::vector<float> result_vector(9); std::vector<float> result_vector(9);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
......
...@@ -14,10 +14,10 @@ void expect_shape(const migraphx::shape& expected, const migraphx::operation& op ...@@ -14,10 +14,10 @@ void expect_shape(const migraphx::shape& expected, const migraphx::operation& op
std::transform( std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
p.add_instruction(op, args); p.add_instruction(op, args);
if(p.get_shape() != expected) if(p.get_output_shapes().back() != expected)
{ {
std::cout << "FAILED: Incorrect shape for " << op.name() << ": "; std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
std::cout << expected << " != " << p.get_shape() << std::endl; std::cout << expected << " != " << p.get_output_shapes().back() << std::endl;
for(auto&& s : shapes) for(auto&& s : shapes)
std::cout << " " << s << std::endl; std::cout << " " << s << std::endl;
} }
......
...@@ -82,8 +82,8 @@ def test_output(): ...@@ -82,8 +82,8 @@ def test_output():
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
p.compile(migraphx.get_target("gpu")) p.compile(migraphx.get_target("gpu"))
r1 = run(p) r1 = run(p)[-1]
r2 = run(p) r2 = run(p)[-1]
assert_eq(r1, r2) assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist()) assert_eq(r1.tolist(), r2.tolist())
......
...@@ -2,11 +2,11 @@ import migraphx ...@@ -2,11 +2,11 @@ import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p) print(p)
s1 = p.get_shape() s1 = p.get_output_shapes()[-1]
print("Compiling ...") print("Compiling ...")
p.compile(migraphx.get_target("cpu")) p.compile(migraphx.get_target("cpu"))
print(p) print(p)
s2 = p.get_shape() s2 = p.get_output_shapes()[-1]
assert s1 == s2 assert s1 == s2
params = {} params = {}
...@@ -14,5 +14,5 @@ for key, value in p.get_parameter_shapes().items(): ...@@ -14,5 +14,5 @@ for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value)) print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.generate_argument(value) params[key] = migraphx.generate_argument(value)
r = p.run(params) r = p.run(params)[-1]
print(r) print(r)
...@@ -11,5 +11,5 @@ for key, value in p.get_parameter_shapes().items(): ...@@ -11,5 +11,5 @@ for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value)) print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value)) params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r = migraphx.from_gpu(p.run(params)) r = migraphx.from_gpu(p.run(params)[-1])
print(r) print(r)
...@@ -866,7 +866,7 @@ TEST_CASE(target_copy) ...@@ -866,7 +866,7 @@ TEST_CASE(target_copy)
} }
} }
auto result = t.copy_from(p.eval(m)); auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}; };
...@@ -923,7 +923,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -923,7 +923,7 @@ TEST_CASE(int8_quantization_dot)
} }
} }
auto result = t.copy_from(p.eval(m)); auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}; };
...@@ -972,7 +972,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -972,7 +972,7 @@ TEST_CASE(int8_quantization_conv)
p.compile(t); p.compile(t);
migraphx::program::parameter_map m; migraphx::program::parameter_map m;
auto result = t.copy_from(p.eval(m)); auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
}; };
......
...@@ -61,8 +61,8 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) ...@@ -61,8 +61,8 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
p1.compile(migraphx::cpu::target{}); p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{}); p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}); auto result2 = p2.eval({}).back();
std::vector<float> results_vector1; std::vector<float> results_vector1;
std::vector<float> results_vector2; std::vector<float> results_vector2;
...@@ -129,8 +129,8 @@ TEST_CASE(as_literal) ...@@ -129,8 +129,8 @@ TEST_CASE(as_literal)
p1.compile(migraphx::cpu::target{}); p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{}); p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}); auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
} }
...@@ -167,8 +167,8 @@ TEST_CASE(literal_reshape) ...@@ -167,8 +167,8 @@ TEST_CASE(literal_reshape)
p1.compile(migraphx::cpu::target{}); p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{}); p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({}); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}); auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
} }
......
...@@ -271,9 +271,9 @@ TEST_CASE(simplify_add_conv1) ...@@ -271,9 +271,9 @@ TEST_CASE(simplify_add_conv1)
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_output_shapes().back();
run_pass(p); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
} }
...@@ -291,9 +291,9 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides) ...@@ -291,9 +291,9 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v); auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_output_shapes().back();
run_pass(p); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_output_shapes().back());
// No fusion // No fusion
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
...@@ -312,9 +312,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1) ...@@ -312,9 +312,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v); auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_output_shapes().back();
run_pass(p); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
} }
...@@ -332,9 +332,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2) ...@@ -332,9 +332,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_output_shapes().back();
run_pass(p); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
} }
...@@ -352,9 +352,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1) ...@@ -352,9 +352,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v); auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_output_shapes().back();
run_pass(p); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_output_shapes().back());
// No fusion // No fusion
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
...@@ -373,9 +373,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2) ...@@ -373,9 +373,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v); auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2); auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto s = p.get_shape(); auto s = p.get_output_shapes().back();
run_pass(p); run_pass(p);
EXPECT(s == p.get_shape()); EXPECT(s == p.get_output_shapes().back());
// No fusion // No fusion
EXPECT(std::count_if( EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2); p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
......
...@@ -20,13 +20,13 @@ TEST_CASE(double_contig) ...@@ -20,13 +20,13 @@ TEST_CASE(double_contig)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 4); EXPECT(std::distance(p.begin(), p.end()) == 4);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
...@@ -37,13 +37,13 @@ TEST_CASE(double_transpose) ...@@ -37,13 +37,13 @@ TEST_CASE(double_transpose)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
...@@ -56,13 +56,13 @@ TEST_CASE(double_transpose_contig) ...@@ -56,13 +56,13 @@ TEST_CASE(double_transpose_contig)
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2); auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
...@@ -72,13 +72,13 @@ TEST_CASE(single_transpose) ...@@ -72,13 +72,13 @@ TEST_CASE(single_transpose)
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1); p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 3); EXPECT(std::distance(p.begin(), p.end()) == 3);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
...@@ -88,14 +88,14 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -88,14 +88,14 @@ TEST_CASE(double_transpose_sin_pass)
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(migraphx::op::transpose{{1, 0}}, t1); p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().standard()); EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_output_shapes().back().transposed());
// TODO: Fix this // TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1); // EXPECT(std::distance(p.begin(), p.end()) == 1);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
...@@ -104,13 +104,13 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -104,13 +104,13 @@ TEST_CASE(single_transpose_sin_pass)
migraphx::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
p.add_instruction(migraphx::op::transpose{{1, 0}}, l); p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_output_shapes().back().transposed());
run_pass(p); run_pass(p);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2); EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({}); auto result = p.eval({}).back();
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
...@@ -124,10 +124,10 @@ TEST_CASE(reshape_transpose) ...@@ -124,10 +124,10 @@ TEST_CASE(reshape_transpose)
auto ct = p.add_instruction(migraphx::op::contiguous{}, t); auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct); auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2); p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s); EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == s); EXPECT(p.get_output_shapes().back() == s);
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
} }
...@@ -139,10 +139,10 @@ TEST_CASE(transpose_contiguous) ...@@ -139,10 +139,10 @@ TEST_CASE(transpose_contiguous)
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x); auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1); p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n); EXPECT(std::distance(p.begin(), p.end()) == n);
} }
...@@ -155,10 +155,10 @@ TEST_CASE(transpose_double_contiguous) ...@@ -155,10 +155,10 @@ TEST_CASE(transpose_double_contiguous)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1); auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t)); EXPECT(p.has_instruction(t));
} }
...@@ -171,10 +171,10 @@ TEST_CASE(transpose_partial1) ...@@ -171,10 +171,10 @@ TEST_CASE(transpose_partial1)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
...@@ -187,10 +187,10 @@ TEST_CASE(transpose_partial2) ...@@ -187,10 +187,10 @@ TEST_CASE(transpose_partial2)
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3); p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
} }
...@@ -204,10 +204,10 @@ TEST_CASE(transpose_partial3) ...@@ -204,10 +204,10 @@ TEST_CASE(transpose_partial3)
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3); auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4); p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3); EXPECT(std::distance(p.begin(), p.end()) == n - 3);
} }
...@@ -218,10 +218,10 @@ TEST_CASE(nop_transpose1) ...@@ -218,10 +218,10 @@ TEST_CASE(nop_transpose1)
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x); auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
...@@ -235,10 +235,10 @@ TEST_CASE(nop_transpose2) ...@@ -235,10 +235,10 @@ TEST_CASE(nop_transpose2)
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2); auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3); auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4); p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4); EXPECT(std::distance(p.begin(), p.end()) == n - 4);
} }
...@@ -252,10 +252,10 @@ TEST_CASE(nop_transpose3) ...@@ -252,10 +252,10 @@ TEST_CASE(nop_transpose3)
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat); auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape() == out_shape); EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1); EXPECT(std::distance(p.begin(), p.end()) == n - 1);
} }
...@@ -270,10 +270,10 @@ TEST_CASE(concat_transpose1) ...@@ -270,10 +270,10 @@ TEST_CASE(concat_transpose1)
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3); EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat = auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
...@@ -292,10 +292,10 @@ TEST_CASE(concat_transpose2) ...@@ -292,10 +292,10 @@ TEST_CASE(concat_transpose2)
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat = auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
...@@ -314,10 +314,10 @@ TEST_CASE(concat_transpose3) ...@@ -314,10 +314,10 @@ TEST_CASE(concat_transpose3)
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt); auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat); auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t); p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat = auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
...@@ -335,10 +335,10 @@ TEST_CASE(nested_concat) ...@@ -335,10 +335,10 @@ TEST_CASE(nested_concat)
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x); auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2); auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
p.add_instruction(pass_op{}, concat3); p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
} }
...@@ -355,10 +355,10 @@ TEST_CASE(nested_concat_partial) ...@@ -355,10 +355,10 @@ TEST_CASE(nested_concat_partial)
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x); auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l); auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
p.add_instruction(pass_op{}, concat3); p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape(); auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end()); auto n = std::distance(p.begin(), p.end());
run_pass(p); run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens()); EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2); EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1); EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
} }
......
...@@ -13,8 +13,8 @@ TEST_CASE(simple_test) ...@@ -13,8 +13,8 @@ TEST_CASE(simple_test)
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraphx::literal{3}); EXPECT(result.back() == migraphx::literal{3});
EXPECT(result != migraphx::literal{4}); EXPECT(result.back() != migraphx::literal{4});
} }
TEST_CASE(out_of_order) TEST_CASE(out_of_order)
......
...@@ -111,15 +111,10 @@ bool equal(const T& x, const T& y) ...@@ -111,15 +111,10 @@ bool equal(const T& x, const T& y)
std::vector<argument> run(program& p, const program::parameter_map& params) std::vector<argument> run(program& p, const program::parameter_map& params)
{ {
auto a = p.eval(params); return p.eval(params);
return {a};
} }
std::vector<shape> get_output_shapes(program& p) std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
{
auto a = p.get_shape();
return {a};
}
void print(const program& p) { std::cout << p << std::endl; } void print(const program& p) { std::cout << p << std::endl; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment