Commit 89375383 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add gpu tests for reduce_mean

parent 766427c9
...@@ -30,11 +30,11 @@ struct id ...@@ -30,11 +30,11 @@ struct id
struct scale struct scale
{ {
float factor = 1.0f; size_t item_num = 1;
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{ {
return static_cast<T>(x * factor); return static_cast<T>(x / item_num);
} }
}; };
......
...@@ -8,9 +8,8 @@ namespace device { ...@@ -8,9 +8,8 @@ namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg) void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{ {
std::size_t batch_item_num = arg.get_shape().elements() / result.get_shape().elements(); std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
float factor = 1.0f / batch_item_num; reduce(stream, result, arg, sum{}, 0, id{}, scale{item_num});
reduce(stream, result, arg, sum{}, 0, id{}, scale{factor});
} }
} // namespace device } // namespace device
......
...@@ -1823,4 +1823,19 @@ TEST_CASE(reduce_mean_test12) ...@@ -1823,4 +1823,19 @@ TEST_CASE(reduce_mean_test12)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_mean_int)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<int> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int> gold{2, 6, 10};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -3499,4 +3499,41 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half> ...@@ -3499,4 +3499,41 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
}; };
}; };
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment