Commit 89375383 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add gpu tests for reduce_mean

parent 766427c9
......@@ -30,11 +30,11 @@ struct id
struct scale
{
float factor = 1.0f;
size_t item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return static_cast<T>(x * factor);
return static_cast<T>(x / item_num);
}
};
......
......@@ -8,9 +8,8 @@ namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{
std::size_t batch_item_num = arg.get_shape().elements() / result.get_shape().elements();
float factor = 1.0f / batch_item_num;
reduce(stream, result, arg, sum{}, 0, id{}, scale{factor});
std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
reduce(stream, result, arg, sum{}, 0, id{}, scale{item_num});
}
} // namespace device
......
......@@ -1823,4 +1823,19 @@ TEST_CASE(reduce_mean_test12)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -3499,4 +3499,41 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
};
};
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment