Commit 9cbd1ec1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add more tests for softmax and logsoftmax

parent e59b3101
......@@ -586,7 +586,7 @@ struct test_softmax2 : verify_program<test_softmax2>
{
migraphx::program p;
auto x =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1028, 1, 25}});
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}});
p.add_instruction(migraphx::op::softmax{}, x);
return p;
}
......@@ -598,7 +598,7 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{T, {2080, 4, 1026, 6}};
migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param);
......@@ -607,13 +607,14 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
};
template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::float_type>;
template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<3, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3, migraphx::shape::double_type>;
// template struct test_softmax<0, migraphx::shape::half_type>;
// template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
struct test_conv : verify_program<test_conv>
{
......@@ -3349,12 +3350,12 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
};
template <int Axis>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1025, 4, 1025, 6}};
migraphx::shape s{migraphx::shape::float_type, {3}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
......@@ -3362,18 +3363,15 @@ struct test_logsoftmax : verify_program<test_logsoftmax<Axis>>
}
};
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template struct test_logsoftmax_1<0>;
template <int Axis>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
......@@ -3381,7 +3379,17 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
}
};
template struct test_logsoftmax_1<0>;
template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment