Commit 9cbd1ec1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add more tests for softmax and logsoftmax

parent e59b3101
...@@ -586,7 +586,7 @@ struct test_softmax2 : verify_program<test_softmax2> ...@@ -586,7 +586,7 @@ struct test_softmax2 : verify_program<test_softmax2>
{ {
migraphx::program p; migraphx::program p;
auto x = auto x =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1028, 1, 25}}); p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraphx::op::softmax{}, x); p.add_instruction(migraphx::op::softmax{}, x);
return p; return p;
} }
...@@ -598,7 +598,7 @@ struct test_softmax : verify_program<test_softmax<Axis, T>> ...@@ -598,7 +598,7 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{T, {2080, 4, 1026, 6}}; migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param); p.add_instruction(migraphx::op::softmax{Axis}, param);
...@@ -607,13 +607,14 @@ struct test_softmax : verify_program<test_softmax<Axis, T>> ...@@ -607,13 +607,14 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
}; };
template struct test_softmax<0, migraphx::shape::float_type>; template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::float_type>;
template struct test_softmax<2, migraphx::shape::float_type>; template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<3, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::double_type>; template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3, migraphx::shape::double_type>; template struct test_softmax<3, migraphx::shape::double_type>;
// template struct test_softmax<0, migraphx::shape::half_type>; template struct test_softmax<0, migraphx::shape::half_type>;
// template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
...@@ -3349,12 +3350,12 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul ...@@ -3349,12 +3350,12 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
}; };
template <int Axis> template <int Axis>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1025, 4, 1025, 6}}; migraphx::shape s{migraphx::shape::float_type, {3}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
...@@ -3362,18 +3363,15 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> ...@@ -3362,18 +3363,15 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
} }
}; };
template struct test_logsoftmax<0>; template struct test_logsoftmax_1<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>> struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
...@@ -3381,7 +3379,17 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>> ...@@ -3381,7 +3379,17 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
} }
}; };
template struct test_logsoftmax_1<0>; template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment