Commit 85995696 authored by Paul's avatar Paul
Browse files

Use reduce_sum to reduce errors in reduce_mean

parent a824f63e
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <int N, migraphx::shape::type_t T> template <int N, migraphx::shape::type_t T>
...@@ -36,7 +37,9 @@ struct test_block_reduce_small : verify_program<test_block_reduce_small<N, T>> ...@@ -36,7 +37,9 @@ struct test_block_reduce_small : verify_program<test_block_reduce_small<N, T>>
migraphx::shape s{T, {2, N}}; migraphx::shape s{T, {2, N}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("x", s); auto y = mm->add_parameter("x", s);
auto r = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x); auto two = mm->add_literal(migraphx::literal{migraphx::shape{s.type(), {1}}, {2}});
auto mul = migraphx::add_common_op(*mm, migraphx::make_op("mul"), {x, two});
auto r = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), mul);
auto rb = auto rb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), r); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), r);
auto add = mm->add_instruction(migraphx::make_op("add"), rb, y); auto add = mm->add_instruction(migraphx::make_op("add"), rb, y);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment