Commit 2ed09f15 authored by Paul's avatar Paul
Browse files

Add cpu reduce_sum

parent 8be483c5
...@@ -79,6 +79,7 @@ struct literal : raw_data<literal> ...@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
template <class Iterator> template <class Iterator>
void fill(Iterator start, Iterator end) void fill(Iterator start, Iterator end)
{ {
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard()) if(m_shape.standard())
{ {
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); }); m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_sum"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis:axes)
lens[axis] = 1;
return {s.type(), lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input){
shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx;
for(auto axis:axes)
out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp> #include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
......
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
......
...@@ -2,7 +2,17 @@ ...@@ -2,7 +2,17 @@
#include <migraphx/cpu/lowering.hpp> #include <migraphx/cpu/lowering.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
......
...@@ -1583,4 +1583,79 @@ TEST_CASE(clip_test) ...@@ -1583,4 +1583,79 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(reduce_sum_test0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{15, 18, 21, 24};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{4, 6, 12, 14, 20, 22};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{3, 7, 11, 15, 19, 23};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{33, 45};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{10, 26, 42};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment