Commit b7c6b8ca authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge reduce mean to test_bert branch

parents 4d358059 a4055723
......@@ -612,7 +612,7 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 4, 1026, 6}};
migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param);
......@@ -3513,4 +3513,40 @@ struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
};
};
struct test_reduce_mean : verify_program<test_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -848,6 +848,27 @@ TEST_CASE(reducesum_test3)
EXPECT(p == prog);
}
TEST_CASE(reducemean_test1)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducemean_test1.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducemean_test2)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemean_test2.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test)
{
migraphx::program p;
......
reducemean-example:}
0
xy"
ReduceMean*
axes@*
keepdimstest_reducemeanZ
x




b
y




B
......@@ -390,37 +390,97 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
template <class T>
void test_argop_var()
TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::op::argmax{0},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::op::argmax{1},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::op::argmax{2},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::op::argmax{3},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::argmax{4}, input);
}
}
TEST_CASE(test_argmin)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::op::argmin{0},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::op::argmin{1},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::op::argmin{2},
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::op::argmin{3},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::op::argmin{4}, input);
}
}
template <class T>
void test_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4}, input);
throws_shape(T{{4}}, input);
}
}
TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
// 2 inputs arguments
TEST_CASE(matmul)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment