You need to sign in or sign up before continuing.
Commit 108599df authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add test cases for reduce mean

parent 81952816
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_mean
{
std::vector<int64_t> axes;
std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -30,7 +30,12 @@ struct reduce_mean
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
{
if (axis < 0 or axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1;
}
return {s.type(), lens};
}
......
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes;
std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -30,7 +30,11 @@ struct reduce_sum
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
{
if (axis < 0 or axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1;
}
return {s.type(), lens};
}
......
......@@ -1295,13 +1295,13 @@ struct onnx_parser
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim);
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
......@@ -1317,8 +1317,7 @@ struct onnx_parser
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
......
......@@ -1763,4 +1763,65 @@ TEST_CASE(reduce_sum_test12)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 3, 6, 7, 10, 11};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5.5, 7.5};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5f, 6.5f, 10.5f};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -390,37 +390,81 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
template <class T>
void test_argop_var()
TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, migraphx::op::argmax{0}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, migraphx::op::argmax{1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, migraphx::op::argmax{2}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, migraphx::op::argmax{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::argmax{4}, input);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, migraphx::op::argmin{0}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, migraphx::op::argmin{1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, migraphx::op::argmin{2}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, migraphx::op::argmin{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::argmin{4}, input);
}
}
template <class T>
void test_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4}, input);
throws_shape(T{{4}}, input);
}
}
TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
// 2 inputs arguments
TEST_CASE(matmul)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment