Commit d0c53a98 authored by Paul's avatar Paul
Browse files

Formatting

parent 2ed09f15
...@@ -26,9 +26,9 @@ struct reduce_sum ...@@ -26,9 +26,9 @@ struct reduce_sum
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
auto lens = s.lens(); auto lens = s.lens();
for(auto axis:axes) for(auto axis : axes)
lens[axis] = 1; lens[axis] = 1;
return {s.type(), lens}; return {s.type(), lens};
} }
...@@ -36,10 +36,10 @@ struct reduce_sum ...@@ -36,10 +36,10 @@ struct reduce_sum
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input){ visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](auto&& in_idx) { shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx; auto out_idx = in_idx;
for(auto axis:axes) for(auto axis : axes)
out_idx[axis] = 0; out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end()); output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
}); });
......
...@@ -1588,7 +1588,7 @@ TEST_CASE(reduce_sum_test0) ...@@ -1588,7 +1588,7 @@ TEST_CASE(reduce_sum_test0)
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input); auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0}}, l0); p.add_instruction(migraphx::op::reduce_sum{{0}}, l0);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1603,7 +1603,7 @@ TEST_CASE(reduce_sum_test1) ...@@ -1603,7 +1603,7 @@ TEST_CASE(reduce_sum_test1)
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input); auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1}}, l0); p.add_instruction(migraphx::op::reduce_sum{{1}}, l0);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1618,7 +1618,7 @@ TEST_CASE(reduce_sum_test2) ...@@ -1618,7 +1618,7 @@ TEST_CASE(reduce_sum_test2)
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input); auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{2}}, l0); p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1633,7 +1633,7 @@ TEST_CASE(reduce_sum_test02) ...@@ -1633,7 +1633,7 @@ TEST_CASE(reduce_sum_test02)
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input); auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0); p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -1648,7 +1648,7 @@ TEST_CASE(reduce_sum_test12) ...@@ -1648,7 +1648,7 @@ TEST_CASE(reduce_sum_test12)
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}; auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input); auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0); p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment