Commit f80ef189 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added test for concat

parent 71d850ae
...@@ -336,7 +336,6 @@ struct cpu_concat ...@@ -336,7 +336,6 @@ struct cpu_concat
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
std::cout << argl << std::endl;
std::size_t nelements = argl.get_shape().elements(); std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto* outptr = output.data() + coffsets[l]; auto* outptr = output.data() + coffsets[l];
......
...@@ -49,18 +49,51 @@ void slice_test() ...@@ -49,18 +49,51 @@ void slice_test()
void concat_test() void concat_test()
{ {
migraph::program p; {
std::size_t axis = 1; migraph::program p;
std::vector<int> data0 = {0, 1, 5, 6}; std::size_t axis = 1;
std::vector<int> data1 = {2, 3, 4, 5, 6, 7}; std::vector<int> data0 = {0, 1, 5, 6};
migraph::shape s0{migraph::shape::int32_type, {2, 2}}; std::vector<int> data1 = {2, 3, 4, 7, 8, 9};
migraph::shape s1{migraph::shape::int32_type, {2, 3}}; std::vector<int> data2 = {10, 20};
auto l0 = p.add_literal(migraph::literal{s0, data0}); migraph::shape s0{migraph::shape::int32_type, {2, 2}};
auto l1 = p.add_literal(migraph::literal{s1, data1}); migraph::shape s1{migraph::shape::int32_type, {2, 3}};
p.add_instruction(migraph::op::concat{axis}, l0, l1); migraph::shape s2{migraph::shape::int32_type, {2, 1}};
p.compile(migraph::cpu::cpu_target{}); auto l0 = p.add_literal(migraph::literal{s0, data0});
auto result = p.eval({}); auto l1 = p.add_literal(migraph::literal{s1, data1});
std::cout << result << std::endl; auto l2 = p.add_literal(migraph::literal{s2, data2});
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
std::vector<int> results_vector(2*6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(migraph::verify_range(result.get_shape().lens(), std::vector<std::size_t>({2, 6})));
EXPECT(migraph::verify_range(result.get_shape().strides(), std::vector<std::size_t>({6, 1})));
}
{
migraph::program p;
std::size_t axis = 0;
std::vector<int> data0 = {0, 1, 2, 3};
std::vector<int> data1 = {4, 5, 6, 7, 8, 9};
std::vector<int> data2 = {10, 11};
migraph::shape s0{migraph::shape::int32_type, {2, 2}};
migraph::shape s1{migraph::shape::int32_type, {3, 2}};
migraph::shape s2{migraph::shape::int32_type, {1, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0});
auto l1 = p.add_literal(migraph::literal{s1, data1});
auto l2 = p.add_literal(migraph::literal{s2, data2});
p.add_instruction(migraph::op::concat{axis}, l0, l1, l2);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(6*2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(migraph::verify_range(result.get_shape().lens(), std::vector<std::size_t>({6, 2})));
EXPECT(migraph::verify_range(result.get_shape().strides(), std::vector<std::size_t>({2, 1})));
}
} }
void squeeze_test() void squeeze_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment