Unverified Commit 1b692d0f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Change api to multiple prog outputs (only API change) (#433)



* Add initial api

* Formatting

* Add more api

* Formatting

* Add auto api generation

* Formatting

* Fix some compilation errors

* Change handle struct

* Formatting

* Fix reamining compilation errors

* Formatting

* Simplify using ctype

* Formatting

* Initial c++ generation

* Formatting

* Add C++header

* Formatting

* Add test

* Formatting

* Add initial tests

* Formatting

* Try to fix formatting

* Cleanup formatting

* Formatting

* Fix constructors on the same line

* Fix tests

* Formatting

* Fix tidy issues

* Fix tidy issues

* Fix naming issue

* Add onnx API to parse buffer

* Formatting

* Add arguments api

* Formatting

* Fix verify parameters

* Fix cppcheck issues

* Formatting

* Add method to get output shapes and bytes

* Formatting

* Try formatting

* Formatting

* Improve the test coverage

* Formatting

* Add print method

* Formatting

* Fix cppcheck issue

* Fix package dependency

* change migraphx api to support multiple program outputs

* clang format

* change api implementation

* clang format

* fix a build error

* change api for correct automatic generation

* clang format

* Add nolint

* Try fix formatting

* Formatting

* formatting

* formatting

* Fix formatting

* code cleanup

* clang format

* fix cppcheck error

* fix review comments

* clang format
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
Co-authored-by: default avatarkahmed10 <15948690+kahmed10@users.noreply.github.com>
parent ba07b221
......@@ -14,7 +14,7 @@ TEST_CASE(instance_norm_test)
migraphx::program p = migraphx::parse_onnx("instance_norm_val_test.onnx");
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
auto result = p.eval({}).back();
std::vector<float> result_vector(9);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
......
......@@ -14,10 +14,10 @@ void expect_shape(const migraphx::shape& expected, const migraphx::operation& op
std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
p.add_instruction(op, args);
if(p.get_shape() != expected)
if(p.get_output_shapes().back() != expected)
{
std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
std::cout << expected << " != " << p.get_shape() << std::endl;
std::cout << expected << " != " << p.get_output_shapes().back() << std::endl;
for(auto&& s : shapes)
std::cout << " " << s << std::endl;
}
......
......@@ -82,8 +82,8 @@ def test_output():
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
p.compile(migraphx.get_target("gpu"))
r1 = run(p)
r2 = run(p)
r1 = run(p)[-1]
r2 = run(p)[-1]
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
......
......@@ -2,11 +2,11 @@ import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p)
s1 = p.get_shape()
s1 = p.get_output_shapes()[-1]
print("Compiling ...")
p.compile(migraphx.get_target("cpu"))
print(p)
s2 = p.get_shape()
s2 = p.get_output_shapes()[-1]
assert s1 == s2
params = {}
......@@ -14,5 +14,5 @@ for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.generate_argument(value)
r = p.run(params)
r = p.run(params)[-1]
print(r)
......@@ -11,5 +11,5 @@ for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r = migraphx.from_gpu(p.run(params))
r = migraphx.from_gpu(p.run(params)[-1])
print(r)
......@@ -866,7 +866,7 @@ TEST_CASE(target_copy)
}
}
auto result = t.copy_from(p.eval(m));
auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
......@@ -923,7 +923,7 @@ TEST_CASE(int8_quantization_dot)
}
}
auto result = t.copy_from(p.eval(m));
auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
......@@ -972,7 +972,7 @@ TEST_CASE(int8_quantization_conv)
p.compile(t);
migraphx::program::parameter_map m;
auto result = t.copy_from(p.eval(m));
auto result = t.copy_from(p.eval(m).back());
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
......
......@@ -61,8 +61,8 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
std::vector<float> results_vector1;
std::vector<float> results_vector2;
......@@ -129,8 +129,8 @@ TEST_CASE(as_literal)
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
......@@ -167,8 +167,8 @@ TEST_CASE(literal_reshape)
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back();
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
......
......@@ -271,9 +271,9 @@ TEST_CASE(simplify_add_conv1)
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}
......@@ -291,9 +291,9 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(s == p.get_output_shapes().back());
// No fusion
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
......@@ -312,9 +312,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}
......@@ -332,9 +332,9 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(s == p.get_output_shapes().back());
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 1);
}
......@@ -352,9 +352,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(s == p.get_output_shapes().back());
// No fusion
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
......@@ -373,9 +373,9 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto s = p.get_shape();
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_shape());
EXPECT(s == p.get_output_shapes().back());
// No fusion
EXPECT(std::count_if(
p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; }) == 2);
......
......@@ -20,13 +20,13 @@ TEST_CASE(double_contig)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 4);
auto result = p.eval({});
auto result = p.eval({}).back();
EXPECT(result != get_2x2());
}
......@@ -37,13 +37,13 @@ TEST_CASE(double_transpose)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
auto result = p.eval({}).back();
EXPECT(result == get_2x2());
}
......@@ -56,13 +56,13 @@ TEST_CASE(double_transpose_contig)
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
auto result = p.eval({}).back();
EXPECT(result == get_2x2());
}
......@@ -72,13 +72,13 @@ TEST_CASE(single_transpose)
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 3);
auto result = p.eval({});
auto result = p.eval({}).back();
EXPECT(result != get_2x2());
}
......@@ -88,14 +88,14 @@ TEST_CASE(double_transpose_sin_pass)
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
// TODO: Fix this
// EXPECT(std::distance(p.begin(), p.end()) == 1);
auto result = p.eval({});
auto result = p.eval({}).back();
EXPECT(result == get_2x2());
}
......@@ -104,13 +104,13 @@ TEST_CASE(single_transpose_sin_pass)
migraphx::program p;
auto l = p.add_literal(get_2x2());
p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
auto result = p.eval({}).back();
EXPECT(result != get_2x2());
}
......@@ -124,10 +124,10 @@ TEST_CASE(reshape_transpose)
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s);
EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == s);
EXPECT(p.get_output_shapes().back() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
......@@ -139,10 +139,10 @@ TEST_CASE(transpose_contiguous)
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_instruction(pass_op{}, c1);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
......@@ -155,10 +155,10 @@ TEST_CASE(transpose_double_contiguous)
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t));
}
......@@ -171,10 +171,10 @@ TEST_CASE(transpose_partial1)
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
......@@ -187,10 +187,10 @@ TEST_CASE(transpose_partial2)
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
......@@ -204,10 +204,10 @@ TEST_CASE(transpose_partial3)
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
......@@ -218,10 +218,10 @@ TEST_CASE(nop_transpose1)
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
......@@ -235,10 +235,10 @@ TEST_CASE(nop_transpose2)
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
......@@ -252,10 +252,10 @@ TEST_CASE(nop_transpose3)
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape() == out_shape);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
......@@ -270,10 +270,10 @@ TEST_CASE(concat_transpose1)
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
......@@ -292,10 +292,10 @@ TEST_CASE(concat_transpose2)
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
......@@ -314,10 +314,10 @@ TEST_CASE(concat_transpose3)
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
......@@ -335,10 +335,10 @@ TEST_CASE(nested_concat)
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
}
......@@ -355,10 +355,10 @@ TEST_CASE(nested_concat_partial)
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
p.add_instruction(pass_op{}, concat3);
auto out_shape = p.get_shape();
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(p.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
EXPECT(std::count_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; }) == 1);
}
......
......@@ -13,8 +13,8 @@ TEST_CASE(simple_test)
p.add_instruction(sum_op{}, one, two);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
EXPECT(result != migraphx::literal{4});
EXPECT(result.back() == migraphx::literal{3});
EXPECT(result.back() != migraphx::literal{4});
}
TEST_CASE(out_of_order)
......
......@@ -111,15 +111,10 @@ bool equal(const T& x, const T& y)
std::vector<argument> run(program& p, const program::parameter_map& params)
{
auto a = p.eval(params);
return {a};
return p.eval(params);
}
std::vector<shape> get_output_shapes(program& p)
{
auto a = p.get_shape();
return {a};
}
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment